Compare commits

..

9 Commits

Author SHA1 Message Date
Yuchen Liang
3ee98775df pageserver: implement aligned io buffer
Signed-off-by: Yuchen Liang <yuchen@neon.tech>
2024-10-07 15:57:49 -04:00
Yuchen Liang
f1418cad52 Merge branch 'main' into yuchen/virtual-file-config 2024-10-07 15:15:26 -04:00
Yuchen Liang
a04cfd754b get rid of io_buffer_alignment config (always 512)
Signed-off-by: Yuchen Liang <yuchen@neon.tech>
2024-10-07 12:16:11 -04:00
Yuchen Liang
bc13310e56 Merge branch 'main' into yuchen/virtual-file-config 2024-10-07 11:49:13 -04:00
Yuchen Liang
5c76b2d474 fix put_io_mode to use the correct http endpoint
Signed-off-by: Yuchen Liang <yuchen@neon.tech>
2024-10-01 10:58:47 -04:00
Yuchen Liang
97f7b0b86f simplify virtual file wrapper
Signed-off-by: Yuchen Liang <yuchen@neon.tech>
2024-10-01 08:31:30 -04:00
Yuchen Liang
3a5b44ea53 add set_io_mode option to getpage_latest_lsn
Signed-off-by: Yuchen Liang <yuchen@neon.tech>
2024-10-01 08:16:18 -04:00
Yuchen Liang
95554c7377 fix clippy
Signed-off-by: Yuchen Liang <yuchen@neon.tech>
2024-10-01 07:59:15 -04:00
Yuchen Liang
a85bd88866 pageserver: add direct io config to virtual file
Signed-off-by: Yuchen Liang <yuchen@neon.tech>
2024-09-30 23:54:14 -04:00
272 changed files with 2032 additions and 3852 deletions

View File

@@ -5,7 +5,9 @@
!Cargo.toml
!Makefile
!rust-toolchain.toml
!scripts/combine_control_files.py
!scripts/ninstall.sh
!vm-cgconfig.conf
!docker-compose/run-tests.sh
# Directories
@@ -15,12 +17,15 @@
!compute_tools/
!control_plane/
!libs/
!neon_local/
!pageserver/
!patches/
!pgxn/
!proxy/
!storage_scrubber/
!safekeeper/
!storage_broker/
!storage_controller/
!trace/
!vendor/postgres-*/
!workspace_hack/

View File

@@ -1,41 +0,0 @@
# Exports statistics about finished runs of the workflows listed below,
# using the gh-workflow-stats action and the database named by DB_URI.
name: Report Workflow Stats
on:
workflow_run:
# Only runs of the workflows named in this list trigger this workflow.
workflows:
- Add `external` label to issues and PRs created by external users
- Benchmarking
- Build and Test
- Build and Test Locally
- Build build-tools image
- Check Permissions
- Check build-tools image
- Check neon with extra platform builds
- Cloud Regression Test
- Create Release Branch
- Handle `approved-for-ci-run` label
- Lint GitHub Workflows
- Notify Slack channel about upcoming release
- Periodic pagebench performance test on dedicated EC2 machine in eu-central-1 region
- Pin build-tools image
- Prepare benchmarking databases by restoring dumps
- Push images to ACR
- Test Postgres client libraries
- Trigger E2E Tests
- cleanup caches by a branch
# Triggered when one of those runs reaches the `completed` state.
types: [completed]
jobs:
gh-workflow-stats:
name: Github Workflow Stats
runs-on: ubuntu-22.04
# The job token is limited to read access on Actions data.
permissions:
actions: read
steps:
- name: Export GH Workflow Stats
# NOTE(review): third-party action referenced by a mutable tag while it
# receives a read-write DB connection string; consider pinning to a commit SHA.
uses: fedordikarev/gh-workflow-stats-action@v0.1.2
with:
# Connection string and target table for the exported statistics.
DB_URI: ${{ secrets.GH_REPORT_STATS_DB_RW_CONNSTR }}
DB_TABLE: "gh_workflow_stats_neon"
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# Id of the workflow run that triggered this one.
GH_RUN_ID: ${{ github.event.workflow_run.id }}

18
Cargo.lock generated
View File

@@ -1820,7 +1820,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47"
dependencies = [
"base16ct 0.2.0",
"base64ct",
"crypto-bigint 0.5.5",
"digest",
"ff 0.13.0",
@@ -1830,8 +1829,6 @@ dependencies = [
"pkcs8 0.10.2",
"rand_core 0.6.4",
"sec1 0.7.3",
"serde_json",
"serdect",
"subtle",
"zeroize",
]
@@ -4040,8 +4037,6 @@ dependencies = [
"bytes",
"fallible-iterator",
"postgres-protocol",
"serde",
"serde_json",
]
[[package]]
@@ -5261,7 +5256,6 @@ dependencies = [
"der 0.7.8",
"generic-array",
"pkcs8 0.10.2",
"serdect",
"subtle",
"zeroize",
]
@@ -5516,16 +5510,6 @@ dependencies = [
"syn 2.0.52",
]
[[package]]
name = "serdect"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a84f14a19e9a014bb9f4512488d9829a68e04ecabffb0f9904cd1ace94598177"
dependencies = [
"base16ct 0.2.0",
"serde",
]
[[package]]
name = "sha1"
version = "0.10.5"
@@ -7318,7 +7302,6 @@ dependencies = [
"num-traits",
"once_cell",
"parquet",
"postgres-types",
"prettyplease",
"proc-macro2",
"prost",
@@ -7343,7 +7326,6 @@ dependencies = [
"time",
"time-macros",
"tokio",
"tokio-postgres",
"tokio-stream",
"tokio-util",
"toml_edit",

View File

@@ -168,27 +168,27 @@ postgres-check-%: postgres-%
neon-pg-ext-%: postgres-%
+@echo "Compiling neon $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile install
+@echo "Compiling neon_walredo $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_walredo/Makefile install
+@echo "Compiling neon_rmgr $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-rmgr-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-rmgr-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_rmgr/Makefile install
+@echo "Compiling neon_test_utils $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile install
+@echo "Compiling neon_utils $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-utils-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile install
@@ -220,7 +220,7 @@ neon-pg-clean-ext-%:
walproposer-lib: neon-pg-ext-v17
+@echo "Compiling walproposer-lib"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config COPT='$(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/walproposer-lib \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile walproposer-lib
cp $(POSTGRES_INSTALL_DIR)/v17/lib/libpgport.a $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
@@ -333,7 +333,7 @@ postgres-%-pgindent: postgres-%-pg-bsd-indent postgres-%-typedefs.list
# Indent pgxn/neon.
.PHONY: neon-pgindent
neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config COPT='$(COPT)' \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
FIND_TYPEDEF=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/find_typedef \
INDENT=$(POSTGRES_INSTALL_DIR)/build/v17/src/tools/pg_bsd_indent/pg_bsd_indent \
PGINDENT_SCRIPT=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/pgindent/pgindent \

View File

@@ -109,30 +109,13 @@ RUN apt update && \
libcgal-dev libgdal-dev libgmp-dev libmpfr-dev libopenscenegraph-dev libprotobuf-c-dev \
protobuf-c-compiler xsltproc
# Postgis 3.5.0 requires SFCGAL 1.4+
#
# It would be nice to update all versions together, but we must solve the SFCGAL dependency first.
# SFCGAL > 1.3 requires CGAL > 5.2, Bullseye's libcgal-dev is 5.2
# and also we must check backward compatibility with older versions of PostGIS.
#
# Use new version only for v17
RUN case "${PG_VERSION}" in \
"v17") \
export SFCGAL_VERSION=1.4.1 \
export SFCGAL_CHECKSUM=1800c8a26241588f11cddcf433049e9b9aea902e923414d2ecef33a3295626c3 \
;; \
"v14" | "v15" | "v16") \
export SFCGAL_VERSION=1.3.10 \
export SFCGAL_CHECKSUM=4e39b3b2adada6254a7bdba6d297bb28e1a9835a9f879b74f37e2dab70203232 \
;; \
*) \
echo "unexpected PostgreSQL version" && exit 1 \
;; \
esac && \
RUN case "${PG_VERSION}" in "v17") \
mkdir -p /sfcgal && \
wget https://gitlab.com/sfcgal/SFCGAL/-/archive/v${SFCGAL_VERSION}/SFCGAL-v${SFCGAL_VERSION}.tar.gz -O SFCGAL.tar.gz && \
echo "${SFCGAL_CHECKSUM} SFCGAL.tar.gz" | sha256sum --check && \
echo "Postgis doensn't yet support PG17 (needs 3.4.3, if not higher)" && exit 0;; \
esac && \
wget https://gitlab.com/Oslandia/SFCGAL/-/archive/v1.3.10/SFCGAL-v1.3.10.tar.gz -O SFCGAL.tar.gz && \
echo "4e39b3b2adada6254a7bdba6d297bb28e1a9835a9f879b74f37e2dab70203232 SFCGAL.tar.gz" | sha256sum --check && \
mkdir sfcgal-src && cd sfcgal-src && tar xzf ../SFCGAL.tar.gz --strip-components=1 -C . && \
cmake -DCMAKE_BUILD_TYPE=Release . && make -j $(getconf _NPROCESSORS_ONLN) && \
DESTDIR=/sfcgal make install -j $(getconf _NPROCESSORS_ONLN) && \
@@ -140,27 +123,15 @@ RUN case "${PG_VERSION}" in \
ENV PATH="/usr/local/pgsql/bin:$PATH"
# Postgis 3.5.0 supports v17
RUN case "${PG_VERSION}" in \
"v17") \
export POSTGIS_VERSION=3.5.0 \
export POSTGIS_CHECKSUM=ca698a22cc2b2b3467ac4e063b43a28413f3004ddd505bdccdd74c56a647f510 \
;; \
"v14" | "v15" | "v16") \
export POSTGIS_VERSION=3.3.3 \
export POSTGIS_CHECKSUM=74eb356e3f85f14233791013360881b6748f78081cc688ff9d6f0f673a762d13 \
;; \
*) \
echo "unexpected PostgreSQL version" && exit 1 \
;; \
RUN case "${PG_VERSION}" in "v17") \
echo "Postgis doensn't yet support PG17 (needs 3.4.3, if not higher)" && exit 0;; \
esac && \
wget https://download.osgeo.org/postgis/source/postgis-${POSTGIS_VERSION}.tar.gz -O postgis.tar.gz && \
echo "${POSTGIS_CHECKSUM} postgis.tar.gz" | sha256sum --check && \
wget https://download.osgeo.org/postgis/source/postgis-3.3.3.tar.gz -O postgis.tar.gz && \
echo "74eb356e3f85f14233791013360881b6748f78081cc688ff9d6f0f673a762d13 postgis.tar.gz" | sha256sum --check && \
mkdir postgis-src && cd postgis-src && tar xzf ../postgis.tar.gz --strip-components=1 -C . && \
find /usr/local/pgsql -type f | sed 's|^/usr/local/pgsql/||' > /before.txt &&\
./autogen.sh && \
./configure --with-sfcgal=/usr/local/bin/sfcgal-config && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \
cd extensions/postgis && \
make clean && \
@@ -181,27 +152,11 @@ RUN case "${PG_VERSION}" in \
cp /usr/local/pgsql/share/extension/address_standardizer.control /extensions/postgis && \
cp /usr/local/pgsql/share/extension/address_standardizer_data_us.control /extensions/postgis
# Uses versioned libraries, i.e. libpgrouting-3.4
# and may introduce function signature changes between releases
# i.e. release 3.5.0 has new signature for pg_dijkstra function
#
# Use new version only for v17
# last release v3.6.2 - Mar 30, 2024
RUN case "${PG_VERSION}" in \
"v17") \
export PGROUTING_VERSION=3.6.2 \
export PGROUTING_CHECKSUM=f4a1ed79d6f714e52548eca3bb8e5593c6745f1bde92eb5fb858efd8984dffa2 \
;; \
"v14" | "v15" | "v16") \
export PGROUTING_VERSION=3.4.2 \
export PGROUTING_CHECKSUM=cac297c07d34460887c4f3b522b35c470138760fe358e351ad1db4edb6ee306e \
;; \
*) \
echo "unexpected PostgreSQL version" && exit 1 \
;; \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/pgRouting/pgrouting/archive/v${PGROUTING_VERSION}.tar.gz -O pgrouting.tar.gz && \
echo "${PGROUTING_CHECKSUM} pgrouting.tar.gz" | sha256sum --check && \
wget https://github.com/pgRouting/pgrouting/archive/v3.4.2.tar.gz -O pgrouting.tar.gz && \
echo "cac297c07d34460887c4f3b522b35c470138760fe358e351ad1db4edb6ee306e pgrouting.tar.gz" | sha256sum --check && \
mkdir pgrouting-src && cd pgrouting-src && tar xzf ../pgrouting.tar.gz --strip-components=1 -C . && \
mkdir build && cd build && \
cmake -DCMAKE_BUILD_TYPE=Release .. && \
@@ -260,9 +215,10 @@ FROM build-deps AS h3-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# not version-specific
# last release v4.1.0 - Jan 18, 2023
RUN mkdir -p /h3/usr/ && \
RUN case "${PG_VERSION}" in "v17") \
mkdir -p /h3/usr/ && \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/uber/h3/archive/refs/tags/v4.1.0.tar.gz -O h3.tar.gz && \
echo "ec99f1f5974846bde64f4513cf8d2ea1b8d172d2218ab41803bf6a63532272bc h3.tar.gz" | sha256sum --check && \
mkdir h3-src && cd h3-src && tar xzf ../h3.tar.gz --strip-components=1 -C . && \
@@ -273,9 +229,10 @@ RUN mkdir -p /h3/usr/ && \
cp -R /h3/usr / && \
rm -rf build
# not version-specific
# last release v4.1.3 - Jul 26, 2023
RUN wget https://github.com/zachasme/h3-pg/archive/refs/tags/v4.1.3.tar.gz -O h3-pg.tar.gz && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/zachasme/h3-pg/archive/refs/tags/v4.1.3.tar.gz -O h3-pg.tar.gz && \
echo "5c17f09a820859ffe949f847bebf1be98511fb8f1bd86f94932512c00479e324 h3-pg.tar.gz" | sha256sum --check && \
mkdir h3-pg-src && cd h3-pg-src && tar xzf ../h3-pg.tar.gz --strip-components=1 -C . && \
export PATH="/usr/local/pgsql/bin:$PATH" && \
@@ -294,10 +251,11 @@ FROM build-deps AS unit-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# not version-specific
# last release 7.9 - Sep 15, 2024
RUN wget https://github.com/df7cb/postgresql-unit/archive/refs/tags/7.9.tar.gz -O postgresql-unit.tar.gz && \
echo "e46de6245dcc8b2c2ecf29873dbd43b2b346773f31dd5ce4b8315895a052b456 postgresql-unit.tar.gz" | sha256sum --check && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/df7cb/postgresql-unit/archive/refs/tags/7.7.tar.gz -O postgresql-unit.tar.gz && \
echo "411d05beeb97e5a4abf17572bfcfbb5a68d98d1018918feff995f6ee3bb03e79 postgresql-unit.tar.gz" | sha256sum --check && \
mkdir postgresql-unit-src && cd postgresql-unit-src && tar xzf ../postgresql-unit.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -344,10 +302,12 @@ FROM build-deps AS pgjwt-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# not version-specific
# doesn't use releases, last commit f3d82fd - Mar 2, 2023
RUN wget https://github.com/michelp/pgjwt/archive/f3d82fd30151e754e19ce5d6a06c71c20689ce3d.tar.gz -O pgjwt.tar.gz && \
echo "dae8ed99eebb7593b43013f6532d772b12dfecd55548d2673f2dfd0163f6d2b9 pgjwt.tar.gz" | sha256sum --check && \
# 9742dab1b2f297ad3811120db7b21451bca2d3c9 made on 13/11/2021
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/michelp/pgjwt/archive/9742dab1b2f297ad3811120db7b21451bca2d3c9.tar.gz -O pgjwt.tar.gz && \
echo "cfdefb15007286f67d3d45510f04a6a7a495004be5b3aecb12cda667e774203f pgjwt.tar.gz" | sha256sum --check && \
mkdir pgjwt-src && cd pgjwt-src && tar xzf ../pgjwt.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgjwt.control
@@ -382,9 +342,10 @@ FROM build-deps AS pg-hashids-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# not version-specific
# last release v1.2.1 -Jan 12, 2018
RUN wget https://github.com/iCyberon/pg_hashids/archive/refs/tags/v1.2.1.tar.gz -O pg_hashids.tar.gz && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/iCyberon/pg_hashids/archive/refs/tags/v1.2.1.tar.gz -O pg_hashids.tar.gz && \
echo "74576b992d9277c92196dd8d816baa2cc2d8046fe102f3dcd7f3c3febed6822a pg_hashids.tar.gz" | sha256sum --check && \
mkdir pg_hashids-src && cd pg_hashids-src && tar xzf ../pg_hashids.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
@@ -444,9 +405,10 @@ FROM build-deps AS ip4r-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# not version-specific
# last release v2.4.2 - Jul 29, 2023
RUN wget https://github.com/RhodiumToad/ip4r/archive/refs/tags/2.4.2.tar.gz -O ip4r.tar.gz && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/RhodiumToad/ip4r/archive/refs/tags/2.4.2.tar.gz -O ip4r.tar.gz && \
echo "0f7b1f159974f49a47842a8ab6751aecca1ed1142b6d5e38d81b064b2ead1b4b ip4r.tar.gz" | sha256sum --check && \
mkdir ip4r-src && cd ip4r-src && tar xzf ../ip4r.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -463,9 +425,10 @@ FROM build-deps AS prefix-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# not version-specific
# last release v1.2.10 - Jul 5, 2023
RUN wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.10.tar.gz -O prefix.tar.gz && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.10.tar.gz -O prefix.tar.gz && \
echo "4342f251432a5f6fb05b8597139d3ccde8dcf87e8ca1498e7ee931ca057a8575 prefix.tar.gz" | sha256sum --check && \
mkdir prefix-src && cd prefix-src && tar xzf ../prefix.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -482,9 +445,10 @@ FROM build-deps AS hll-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# not version-specific
# last release v2.18 - Aug 29, 2023
RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.18.tar.gz -O hll.tar.gz && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.18.tar.gz -O hll.tar.gz && \
echo "e2f55a6f4c4ab95ee4f1b4a2b73280258c5136b161fe9d059559556079694f0e hll.tar.gz" | sha256sum --check && \
mkdir hll-src && cd hll-src && tar xzf ../hll.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -695,10 +659,11 @@ FROM build-deps AS pg-roaringbitmap-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# not version-specific
# last release v0.5.4 - Jun 28, 2022
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN wget https://github.com/ChenHuajun/pg_roaringbitmap/archive/refs/tags/v0.5.4.tar.gz -O pg_roaringbitmap.tar.gz && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions is not supported yet by pg_roaringbitmap. Quit" && exit 0;; \
esac && \
wget https://github.com/ChenHuajun/pg_roaringbitmap/archive/refs/tags/v0.5.4.tar.gz -O pg_roaringbitmap.tar.gz && \
echo "b75201efcb1c2d1b014ec4ae6a22769cc7a224e6e406a587f5784a37b6b5a2aa pg_roaringbitmap.tar.gz" | sha256sum --check && \
mkdir pg_roaringbitmap-src && cd pg_roaringbitmap-src && tar xzf ../pg_roaringbitmap.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) && \
@@ -715,27 +680,12 @@ FROM build-deps AS pg-semver-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# Release 0.40.0 breaks backward compatibility with previous versions
# see release note https://github.com/theory/pg-semver/releases/tag/v0.40.0
# Use new version only for v17
#
# last release v0.40.0 - Jul 22, 2024
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN case "${PG_VERSION}" in \
"v17") \
export SEMVER_VERSION=0.40.0 \
export SEMVER_CHECKSUM=3e50bcc29a0e2e481e7b6d2bc937cadc5f5869f55d983b5a1aafeb49f5425cfc \
;; \
"v14" | "v15" | "v16") \
export SEMVER_VERSION=0.32.1 \
export SEMVER_CHECKSUM=fbdaf7512026d62eec03fad8687c15ed509b6ba395bff140acd63d2e4fbe25d7 \
;; \
*) \
echo "unexpected PostgreSQL version" && exit 1 \
;; \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 is not supported yet by pg_semver. Quit" && exit 0;; \
esac && \
wget https://github.com/theory/pg-semver/archive/refs/tags/v${SEMVER_VERSION}.tar.gz -O pg_semver.tar.gz && \
echo "${SEMVER_CHECKSUM} pg_semver.tar.gz" | sha256sum --check && \
wget https://github.com/theory/pg-semver/archive/refs/tags/v0.32.1.tar.gz -O pg_semver.tar.gz && \
echo "fbdaf7512026d62eec03fad8687c15ed509b6ba395bff140acd63d2e4fbe25d7 pg_semver.tar.gz" | sha256sum --check && \
mkdir pg_semver-src && cd pg_semver-src && tar xzf ../pg_semver.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \

View File

@@ -1484,28 +1484,6 @@ LIMIT 100",
info!("Pageserver config changed");
}
}
// Gather info about installed extensions
pub fn get_installed_extensions(&self) -> Result<()> {
let connstr = self.connstr.clone();
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("failed to create runtime");
let result = rt
.block_on(crate::installed_extensions::get_installed_extensions(
connstr,
))
.expect("failed to get installed extensions");
info!(
"{}",
serde_json::to_string(&result).expect("failed to serialize extensions list")
);
Ok(())
}
}
pub fn forward_termination_signal() {

View File

@@ -165,32 +165,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
// get the list of installed extensions
// currently only used in python tests
// TODO: call it from cplane
(&Method::GET, "/installed_extensions") => {
info!("serving /installed_extensions GET request");
let status = compute.get_status();
if status != ComputeStatus::Running {
let msg = format!(
"invalid compute status for extensions request: {:?}",
status
);
error!(msg);
return Response::new(Body::from(msg));
}
let connstr = compute.connstr.clone();
let res = crate::installed_extensions::get_installed_extensions(connstr).await;
match res {
Ok(res) => render_json(Body::from(serde_json::to_string(&res).unwrap())),
Err(e) => render_json_error(
&format!("could not get list of installed extensions: {}", e),
StatusCode::INTERNAL_SERVER_ERROR,
),
}
}
// download extension files from remote extension storage on demand
(&Method::POST, route) if route.starts_with("/extension_server/") => {
info!("serving {:?} POST request", route);

View File

@@ -53,20 +53,6 @@ paths:
schema:
$ref: "#/components/schemas/ComputeInsights"
/installed_extensions:
get:
tags:
- Info
summary: Get installed extensions.
description: ""
operationId: getInstalledExtensions
responses:
200:
description: List of installed extensions
content:
application/json:
schema:
$ref: "#/components/schemas/InstalledExtensions"
/info:
get:
tags:
@@ -409,24 +395,6 @@ components:
- configuration
example: running
InstalledExtensions:
type: object
properties:
extensions:
description: Contains list of installed extensions.
type: array
items:
type: object
properties:
extname:
type: string
versions:
type: array
items:
type: string
n_databases:
type: integer
#
# Errors
#

View File

@@ -1,80 +0,0 @@
use compute_api::responses::{InstalledExtension, InstalledExtensions};
use std::collections::HashMap;
use std::collections::HashSet;
use url::Url;
use anyhow::Result;
use postgres::{Client, NoTls};
use tokio::task;
/// Lists the names of the databases that can be connected to.
///
/// get_existing_dbs() is deliberately not reused here: keeping the query local
/// makes the database listing explicit and the code easier to follow.
///
/// At most 500 databases are returned to avoid excessive load.
fn list_dbs(client: &mut Client) -> Result<Vec<String>> {
    // `pg_database.datconnlimit = -2` marks a database that is in the
    // invalid state, so such rows are filtered out by the query.
    let rows = client.query(
        "SELECT datname FROM pg_catalog.pg_database
WHERE datallowconn
AND datconnlimit <> - 2
LIMIT 500",
        &[],
    )?;

    let mut names = Vec::with_capacity(rows.len());
    for row in &rows {
        let name: String = row.get("datname");
        names.push(name);
    }
    Ok(names)
}
/// Connect to every database (see list_dbs above) and get the list of installed extensions.
///
/// The same extension can be installed in multiple databases with different versions.
/// Every distinct version seen across the databases is recorded in `versions`, and
/// `n_databases` counts the databases in which the extension was found.
///
/// The work uses the blocking `postgres` client, so it is run via
/// `tokio::task::spawn_blocking`.
///
/// # Errors
///
/// Fails if any connection or query fails, or if the blocking task cannot be joined.
pub async fn get_installed_extensions(connstr: Url) -> Result<InstalledExtensions> {
    // `connstr` is already owned: move it into the closure and retarget it per database,
    // no extra clone needed.
    let mut connstr = connstr;

    task::spawn_blocking(move || {
        let mut client = Client::connect(connstr.as_str(), NoTls)?;
        let databases: Vec<String> = list_dbs(&mut client)?;

        let mut extensions_map: HashMap<String, InstalledExtension> = HashMap::new();
        for db in &databases {
            // Point the connection string at this database before connecting.
            connstr.set_path(db);
            let mut db_client = Client::connect(connstr.as_str(), NoTls)?;

            let rows = db_client.query(
                "SELECT extname, extversion FROM pg_catalog.pg_extension;",
                &[],
            )?;

            for row in &rows {
                let extname: String = row.get("extname");
                let version: String = row.get("extversion");

                // A new entry starts empty; the version and database count are
                // then recorded the same way for new and existing entries.
                let entry = extensions_map
                    .entry(extname.clone())
                    .or_insert_with(|| InstalledExtension {
                        extname,
                        versions: HashSet::new(),
                        n_databases: 0,
                    });
                entry.versions.insert(version);
                // count the number of databases where the extension is installed
                entry.n_databases += 1;
            }
        }

        Ok(InstalledExtensions {
            // Consume the map instead of cloning every value out of it.
            extensions: extensions_map.into_values().collect(),
        })
    })
    .await?
}

View File

@@ -15,7 +15,6 @@ pub mod catalog;
pub mod compute;
pub mod disk_quota;
pub mod extension_server;
pub mod installed_extensions;
pub mod local_proxy;
pub mod lsn_lease;
mod migration;

View File

@@ -1,6 +1,5 @@
//! Structs representing the JSON formats used in the compute_ctl's HTTP API.
use std::collections::HashSet;
use std::fmt::Display;
use chrono::{DateTime, Utc};
@@ -156,15 +155,3 @@ pub enum ControlPlaneComputeStatus {
// should be able to start with provided spec.
Attached,
}
/// One installed extension, aggregated over the databases that were inspected.
#[derive(Clone, Debug, Default, Serialize)]
pub struct InstalledExtension {
// Extension name.
pub extname: String,
// Distinct versions of this extension that were found.
pub versions: HashSet<String>,
pub n_databases: u32, // Number of databases using this extension
}
/// Serializable wrapper around the list of installed extensions.
#[derive(Clone, Debug, Default, Serialize)]
pub struct InstalledExtensions {
// One entry per extension name.
pub extensions: Vec<InstalledExtension>,
}

View File

@@ -496,12 +496,26 @@ impl RemoteStorage for AzureBlobStorage {
builder = builder.if_match(IfMatchCondition::NotMatch(etag.to_string()))
}
if let Some((start, end)) = opts.byte_range() {
builder = builder.range(match end {
Some(end) => Range::Range(start..end),
None => Range::RangeFrom(start..),
});
}
self.download_for_builder(builder, cancel).await
}
async fn download_byte_range(
&self,
from: &RemotePath,
start_inclusive: u64,
end_exclusive: Option<u64>,
cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
let blob_client = self.client.blob_client(self.relative_path_to_name(from));
let mut builder = blob_client.get();
let range: Range = if let Some(end_exclusive) = end_exclusive {
(start_inclusive..end_exclusive).into()
} else {
(start_inclusive..).into()
};
builder = builder.range(range);
self.download_for_builder(builder, cancel).await
}

View File

@@ -19,8 +19,7 @@ mod simulate_failures;
mod support;
use std::{
collections::HashMap, fmt::Debug, num::NonZeroU32, ops::Bound, pin::Pin, sync::Arc,
time::SystemTime,
collections::HashMap, fmt::Debug, num::NonZeroU32, pin::Pin, sync::Arc, time::SystemTime,
};
use anyhow::Context;
@@ -163,60 +162,11 @@ pub struct Listing {
}
/// Options for downloads. The default value is a plain GET.
#[derive(Default)]
pub struct DownloadOpts {
/// If given, returns [`DownloadError::Unmodified`] if the object still has
/// the same ETag (using If-None-Match).
pub etag: Option<Etag>,
/// The start of the byte range to download, or unbounded.
pub byte_start: Bound<u64>,
/// The end of the byte range to download, or unbounded. Must be after the
/// start bound.
pub byte_end: Bound<u64>,
}
impl Default for DownloadOpts {
    /// No ETag condition and no byte range restriction on either side.
    fn default() -> Self {
        let unbounded = Bound::Unbounded;
        Self {
            etag: None,
            byte_start: unbounded,
            byte_end: unbounded,
        }
    }
}
impl DownloadOpts {
    /// Normalizes the configured bounds into `(inclusive start, exclusive end)`.
    ///
    /// Returns `None` when both bounds are unbounded. The end is `None` when only
    /// the end bound is unbounded. Panics if the resulting end is not after the start.
    pub fn byte_range(&self) -> Option<(u64, Option<u64>)> {
        if matches!(
            (&self.byte_start, &self.byte_end),
            (Bound::Unbounded, Bound::Unbounded)
        ) {
            return None;
        }

        let start = match self.byte_start {
            Bound::Unbounded => 0,
            Bound::Included(first) => first,
            Bound::Excluded(before_first) => before_first + 1,
        };
        let end = match self.byte_end {
            Bound::Unbounded => None,
            Bound::Included(last) => Some(last + 1),
            Bound::Excluded(past_last) => Some(past_last),
        };

        if let Some(end) = end {
            assert!(start < end, "range end {end} at or before start {start}");
        }
        Some((start, end))
    }

    /// Formats the byte range as a `Range` header value (`bytes=start-end` with an
    /// inclusive end, or `bytes=start-` when open-ended). `None` if unbounded.
    pub fn byte_range_header(&self) -> Option<String> {
        let (start, end) = self.byte_range()?;
        // The header uses an inclusive end, so step back by one.
        let header = match end.map(|end| end - 1) {
            Some(end) => format!("bytes={start}-{end}"),
            None => format!("bytes={start}-"),
        };
        Some(header)
    }
}
/// Storage (potentially remote) API to manage its state.
@@ -307,6 +257,21 @@ pub trait RemoteStorage: Send + Sync + 'static {
cancel: &CancellationToken,
) -> Result<Download, DownloadError>;
/// Streams a given byte range of the remote storage entry contents.
///
/// The range starts at `start_inclusive` and stops before `end_exclusive`;
/// when `end_exclusive` is `None` the range is open-ended.
///
/// The returned download stream will obey initial timeout and cancellation signal by erroring
/// on whichever happens first. Only one of the reasons will fail the stream, which is usually
/// enough for `tokio::io::copy_buf` usage. If needed the error can be filtered out.
///
/// Returns the metadata, if any was stored with the file previously.
async fn download_byte_range(
&self,
from: &RemotePath,
start_inclusive: u64,
end_exclusive: Option<u64>,
cancel: &CancellationToken,
) -> Result<Download, DownloadError>;
/// Delete a single path from remote storage.
///
/// If the operation fails because of timeout or cancellation, the root cause of the error will be
@@ -460,6 +425,33 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
}
}
/// See [`RemoteStorage::download_byte_range`]
pub async fn download_byte_range(
&self,
from: &RemotePath,
start_inclusive: u64,
end_exclusive: Option<u64>,
cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
// Forward the call unchanged to whichever backend this value wraps.
match self {
Self::LocalFs(s) => {
s.download_byte_range(from, start_inclusive, end_exclusive, cancel)
.await
}
Self::AwsS3(s) => {
s.download_byte_range(from, start_inclusive, end_exclusive, cancel)
.await
}
Self::AzureBlob(s) => {
s.download_byte_range(from, start_inclusive, end_exclusive, cancel)
.await
}
Self::Unreliable(s) => {
s.download_byte_range(from, start_inclusive, end_exclusive, cancel)
.await
}
}
}
/// See [`RemoteStorage::delete`]
pub async fn delete(
&self,
@@ -581,6 +573,20 @@ impl GenericRemoteStorage {
})
}
/// Starts a download of the storage object at `from`, whole or in part.
///
/// With `byte_range = Some((start, end))` only that part is requested via
/// `download_byte_range`. With `None` the whole object is requested via
/// `download` with default [`DownloadOpts`].
pub async fn download_storage_object(
    &self,
    byte_range: Option<(u64, Option<u64>)>,
    from: &RemotePath,
    cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
    if let Some((start, end)) = byte_range {
        return self.download_byte_range(from, start, end, cancel).await;
    }
    self.download(from, &DownloadOpts::default(), cancel).await
}
/// The name of the bucket/container/etc.
pub fn bucket_name(&self) -> Option<&str> {
match self {
@@ -654,76 +660,6 @@ impl ConcurrencyLimiter {
mod tests {
use super::*;
/// Table-driven check that DownloadOpts::byte_range() yields (inclusive, exclusive)
/// ranges with an optional end bound, or None when unbounded, and that
/// byte_range_header() renders the same range with an inclusive end.
#[test]
fn download_opts_byte_range() {
    // Consider using test_case or a similar table-driven test framework.
    // Columns: byte_start, byte_end, expected byte_range() result.
    let table = [
        (Bound::Unbounded, Bound::Unbounded, None),
        (Bound::Unbounded, Bound::Included(7), Some((0, Some(8)))),
        (Bound::Unbounded, Bound::Excluded(7), Some((0, Some(7)))),
        (Bound::Included(3), Bound::Unbounded, Some((3, None))),
        (Bound::Included(3), Bound::Included(7), Some((3, Some(8)))),
        (Bound::Included(3), Bound::Excluded(7), Some((3, Some(7)))),
        (Bound::Excluded(3), Bound::Unbounded, Some((4, None))),
        (Bound::Excluded(3), Bound::Included(7), Some((4, Some(8)))),
        (Bound::Excluded(3), Bound::Excluded(7), Some((4, Some(7)))),
        // 1-sized ranges are fine, 0 aren't and will panic (separate test).
        (Bound::Included(3), Bound::Included(3), Some((3, Some(4)))),
        (Bound::Included(3), Bound::Excluded(4), Some((3, Some(4)))),
    ];

    for (byte_start, byte_end, expect) in table {
        let opts = DownloadOpts {
            byte_start,
            byte_end,
            ..Default::default()
        };

        assert_eq!(
            opts.byte_range(),
            expect,
            "byte_start={byte_start:?} byte_end={byte_end:?}"
        );

        // The HTTP header uses an inclusive end, one less than the exclusive end.
        let expect_header = match expect {
            Some((start, Some(end))) => Some(format!("bytes={start}-{}", end - 1)),
            Some((start, None)) => Some(format!("bytes={start}-")),
            None => None,
        };
        assert_eq!(
            opts.byte_range_header(),
            expect_header,
            "byte_start={byte_start:?} byte_end={byte_end:?}"
        );
    }
}
/// An empty range (exclusive end equal to the inclusive start) must make
/// `DownloadOpts::byte_range()` panic.
#[test]
#[should_panic]
fn download_opts_byte_range_zero() {
    let empty_range = DownloadOpts {
        byte_start: Bound::Included(3),
        byte_end: Bound::Excluded(3),
        ..Default::default()
    };
    empty_range.byte_range();
}
/// An inverted range (inclusive end before the inclusive start) must make
/// `DownloadOpts::byte_range()` panic.
#[test]
#[should_panic]
fn download_opts_byte_range_negative() {
    let inverted_range = DownloadOpts {
        byte_start: Bound::Included(3),
        byte_end: Bound::Included(2),
        ..Default::default()
    };
    inverted_range.byte_range();
}
#[test]
fn test_object_name() {
let k = RemotePath::new(Utf8Path::new("a/b/c")).unwrap();

View File

@@ -506,29 +506,16 @@ impl RemoteStorage for LocalFs {
return Err(DownloadError::Unmodified);
}
let mut file = fs::OpenOptions::new()
.read(true)
.open(&target_path)
.await
.with_context(|| {
format!("Failed to open source file {target_path:?} to use in the download")
})
.map_err(DownloadError::Other)?;
let mut take = file_metadata.len();
if let Some((start, end)) = opts.byte_range() {
if start > 0 {
file.seek(io::SeekFrom::Start(start))
.await
.context("Failed to seek to the range start in a local storage file")
.map_err(DownloadError::Other)?;
}
if let Some(end) = end {
take = end - start;
}
}
let source = ReaderStream::new(file.take(take));
let source = ReaderStream::new(
fs::OpenOptions::new()
.read(true)
.open(&target_path)
.await
.with_context(|| {
format!("Failed to open source file {target_path:?} to use in the download")
})
.map_err(DownloadError::Other)?,
);
let metadata = self
.read_storage_metadata(&target_path)
@@ -548,6 +535,68 @@ impl RemoteStorage for LocalFs {
})
}
/// Streams the byte range `[start_inclusive, end_exclusive)` of a locally stored object.
///
/// When `end_exclusive` is `None` the stream runs from `start_inclusive` to the end of
/// the file. A bounded end that lies beyond the end of the file simply yields the bytes
/// that exist; an unbounded range starting past the end of the file yields an empty stream.
///
/// # Errors
/// Returns [`DownloadError::Other`] when
/// * `end_exclusive <= start_inclusive`,
/// * `start_inclusive == end_exclusive - 1` (see the review note in the body),
/// * the file cannot be opened, its length queried, or the seek fails,
/// * the storage metadata or the modification time cannot be read.
///
/// Errors from `file_metadata` are propagated unchanged.
async fn download_byte_range(
    &self,
    from: &RemotePath,
    start_inclusive: u64,
    end_exclusive: Option<u64>,
    cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
    if let Some(end_exclusive) = end_exclusive {
        if end_exclusive <= start_inclusive {
            return Err(DownloadError::Other(anyhow::anyhow!("Invalid range, start ({start_inclusive}) is not less than end_exclusive ({end_exclusive:?})")));
        };
        // NOTE(review): with an exclusive end this condition matches a *one*-byte range,
        // not a zero-byte one, so `[n, n + 1)` is rejected and the message is misleading.
        // Kept as is because existing callers/tests depend on this error — confirm
        // before relaxing it.
        if start_inclusive == end_exclusive.saturating_sub(1) {
            return Err(DownloadError::Other(anyhow::anyhow!("Invalid range, start ({start_inclusive}) and end_exclusive ({end_exclusive:?}) difference is zero bytes")));
        }
    }

    let target_path = from.with_base(&self.storage_root);
    let file_metadata = file_metadata(&target_path).await?;
    let mut source = tokio::fs::OpenOptions::new()
        .read(true)
        .open(&target_path)
        .await
        .with_context(|| {
            format!("Failed to open source file {target_path:?} to use in the download")
        })
        .map_err(DownloadError::Other)?;

    let len = source
        .metadata()
        .await
        .context("query file length")
        .map_err(DownloadError::Other)?
        .len();

    source
        .seek(io::SeekFrom::Start(start_inclusive))
        .await
        .context("Failed to seek to the range start in a local storage file")
        .map_err(DownloadError::Other)?;

    let metadata = self
        .read_storage_metadata(&target_path)
        .await
        .map_err(DownloadError::Other)?;

    // For a bounded range the validation above guarantees `end > start`. For an
    // unbounded range `start_inclusive` may exceed the file length, so saturate
    // instead of underflowing (which panics in builds with overflow checks).
    let bytes_to_read = end_exclusive
        .unwrap_or(len)
        .saturating_sub(start_inclusive);
    let source = source.take(bytes_to_read);
    let source = ReaderStream::new(source);

    // Wrap the stream so that it stops on cancellation or on the configured timeout.
    let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
    let source = crate::support::DownloadStream::new(cancel_or_timeout, source);

    let etag = mock_etag(&file_metadata);
    Ok(Download {
        metadata,
        last_modified: file_metadata
            .modified()
            .map_err(|e| DownloadError::Other(anyhow::anyhow!(e).context("Reading mtime")))?,
        etag,
        download_stream: Box::pin(source),
    })
}
async fn delete(&self, path: &RemotePath, _cancel: &CancellationToken) -> anyhow::Result<()> {
let file_path = path.with_base(&self.storage_root);
match fs::remove_file(&file_path).await {
@@ -639,7 +688,7 @@ mod fs_tests {
use super::*;
use camino_tempfile::tempdir;
use std::{collections::HashMap, io::Write, ops::Bound};
use std::{collections::HashMap, io::Write};
async fn read_and_check_metadata(
storage: &LocalFs,
@@ -755,12 +804,10 @@ mod fs_tests {
let (first_part_local, second_part_local) = uploaded_bytes.split_at(3);
let first_part_download = storage
.download(
.download_byte_range(
&upload_target,
&DownloadOpts {
byte_end: Bound::Excluded(first_part_local.len() as u64),
..Default::default()
},
0,
Some(first_part_local.len() as u64),
&cancel,
)
.await?;
@@ -776,15 +823,10 @@ mod fs_tests {
);
let second_part_download = storage
.download(
.download_byte_range(
&upload_target,
&DownloadOpts {
byte_start: Bound::Included(first_part_local.len() as u64),
byte_end: Bound::Excluded(
(first_part_local.len() + second_part_local.len()) as u64,
),
..Default::default()
},
first_part_local.len() as u64,
Some((first_part_local.len() + second_part_local.len()) as u64),
&cancel,
)
.await?;
@@ -800,14 +842,7 @@ mod fs_tests {
);
let suffix_bytes = storage
.download(
&upload_target,
&DownloadOpts {
byte_start: Bound::Included(13),
..Default::default()
},
&cancel,
)
.download_byte_range(&upload_target, 13, None, &cancel)
.await?
.download_stream;
let suffix_bytes = aggregate(suffix_bytes).await?;
@@ -815,7 +850,7 @@ mod fs_tests {
assert_eq!(upload_name, suffix);
let all_bytes = storage
.download(&upload_target, &DownloadOpts::default(), &cancel)
.download_byte_range(&upload_target, 0, None, &cancel)
.await?
.download_stream;
let all_bytes = aggregate(all_bytes).await?;
@@ -826,26 +861,48 @@ mod fs_tests {
}
#[tokio::test]
#[should_panic(expected = "at or before start")]
async fn download_file_range_negative() {
let (storage, cancel) = create_storage().unwrap();
async fn download_file_range_negative() -> anyhow::Result<()> {
let (storage, cancel) = create_storage()?;
let upload_name = "upload_1";
let upload_target = upload_dummy_file(&storage, upload_name, None, &cancel)
.await
.unwrap();
let upload_target = upload_dummy_file(&storage, upload_name, None, &cancel).await?;
storage
.download(
let start = 1_000_000_000;
let end = start + 1;
match storage
.download_byte_range(
&upload_target,
&DownloadOpts {
byte_start: Bound::Included(10),
byte_end: Bound::Excluded(10),
..Default::default()
},
start,
Some(end), // exclusive end
&cancel,
)
.await
.unwrap();
{
Ok(_) => panic!("Should not allow downloading wrong ranges"),
Err(e) => {
let error_string = e.to_string();
assert!(error_string.contains("zero bytes"));
assert!(error_string.contains(&start.to_string()));
assert!(error_string.contains(&end.to_string()));
}
}
let start = 10000;
let end = 234;
assert!(start > end, "Should test an incorrect range");
match storage
.download_byte_range(&upload_target, start, Some(end), &cancel)
.await
{
Ok(_) => panic!("Should not allow downloading wrong ranges"),
Err(e) => {
let error_string = e.to_string();
assert!(error_string.contains("Invalid range"));
assert!(error_string.contains(&start.to_string()));
assert!(error_string.contains(&end.to_string()));
}
}
Ok(())
}
#[tokio::test]
@@ -888,12 +945,10 @@ mod fs_tests {
let (first_part_local, _) = uploaded_bytes.split_at(3);
let partial_download_with_metadata = storage
.download(
.download_byte_range(
&upload_target,
&DownloadOpts {
byte_end: Bound::Excluded(first_part_local.len() as u64),
..Default::default()
},
0,
Some(first_part_local.len() as u64),
&cancel,
)
.await?;

View File

@@ -804,7 +804,34 @@ impl RemoteStorage for S3Bucket {
bucket: self.bucket_name.clone(),
key: self.relative_path_to_s3_object(from),
etag: opts.etag.as_ref().map(|e| e.to_string()),
range: opts.byte_range_header(),
range: None,
},
cancel,
)
.await
}
async fn download_byte_range(
&self,
from: &RemotePath,
start_inclusive: u64,
end_exclusive: Option<u64>,
cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
// S3 accepts ranges as https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35
// and needs both ends to be inclusive
let end_inclusive = end_exclusive.map(|end| end.saturating_sub(1));
let range = Some(match end_inclusive {
Some(end_inclusive) => format!("bytes={start_inclusive}-{end_inclusive}"),
None => format!("bytes={start_inclusive}-"),
});
self.download_object(
GetObjectRequest {
bucket: self.bucket_name.clone(),
key: self.relative_path_to_s3_object(from),
etag: None,
range,
},
cancel,
)

View File

@@ -170,13 +170,28 @@ impl RemoteStorage for UnreliableWrapper {
opts: &DownloadOpts,
cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
// Note: We treat any byte range as an "attempt" of the same operation.
// We don't pay attention to the ranges. That's good enough for now.
self.attempt(RemoteOp::Download(from.clone()))
.map_err(DownloadError::Other)?;
self.inner.download(from, opts, cancel).await
}
/// Records the call as an attempt of `RemoteOp::Download` for `from` and, if the
/// attempt is allowed to proceed, forwards the range request to the wrapped storage.
///
/// An error returned by `attempt` is surfaced as [`DownloadError::Other`].
async fn download_byte_range(
    &self,
    from: &RemotePath,
    start_inclusive: u64,
    end_exclusive: Option<u64>,
    cancel: &CancellationToken,
) -> Result<Download, DownloadError> {
    // Note: We treat any download_byte_range as an "attempt" of the same
    // operation. We don't pay attention to the ranges. That's good enough
    // for now.
    let op = RemoteOp::Download(from.clone());
    if let Err(attempt_err) = self.attempt(op) {
        return Err(DownloadError::Other(attempt_err));
    }

    let wrapped = &self.inner;
    wrapped
        .download_byte_range(from, start_inclusive, end_exclusive, cancel)
        .await
}
async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
self.delete_inner(path, true, cancel).await
}

View File

@@ -2,7 +2,6 @@ use anyhow::Context;
use camino::Utf8Path;
use futures::StreamExt;
use remote_storage::{DownloadError, DownloadOpts, ListingMode, ListingObject, RemotePath};
use std::ops::Bound;
use std::sync::Arc;
use std::{collections::HashSet, num::NonZeroU32};
use test_context::test_context;
@@ -294,15 +293,7 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<
// Full range (end specified)
let dl = ctx
.client
.download(
&path,
&DownloadOpts {
byte_start: Bound::Included(0),
byte_end: Bound::Excluded(len as u64),
..Default::default()
},
&cancel,
)
.download_byte_range(&path, 0, Some(len as u64), &cancel)
.await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig);
@@ -310,15 +301,7 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<
// partial range (end specified)
let dl = ctx
.client
.download(
&path,
&DownloadOpts {
byte_start: Bound::Included(4),
byte_end: Bound::Excluded(10),
..Default::default()
},
&cancel,
)
.download_byte_range(&path, 4, Some(10), &cancel)
.await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig[4..10]);
@@ -326,15 +309,7 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<
// partial range (end beyond real end)
let dl = ctx
.client
.download(
&path,
&DownloadOpts {
byte_start: Bound::Included(8),
byte_end: Bound::Excluded(len as u64 * 100),
..Default::default()
},
&cancel,
)
.download_byte_range(&path, 8, Some(len as u64 * 100), &cancel)
.await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig[8..]);
@@ -342,14 +317,7 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<
// Partial range (end unspecified)
let dl = ctx
.client
.download(
&path,
&DownloadOpts {
byte_start: Bound::Included(4),
..Default::default()
},
&cancel,
)
.download_byte_range(&path, 4, None, &cancel)
.await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig[4..]);
@@ -357,14 +325,7 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<
// Full range (end unspecified)
let dl = ctx
.client
.download(
&path,
&DownloadOpts {
byte_start: Bound::Included(0),
..Default::default()
},
&cancel,
)
.download_byte_range(&path, 0, None, &cancel)
.await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig);

View File

@@ -79,7 +79,8 @@ pub struct Config {
/// memory.
///
/// The default value of `0.15` means that we *guarantee* sending upscale requests if the
/// cgroup is using more than 85% of total memory.
/// cgroup is using more than 85% of total memory (even if we're *not* separately reserving
/// memory for the file cache).
cgroup_min_overhead_fraction: f64,
cgroup_downscale_threshold_buffer_bytes: u64,
@@ -96,12 +97,24 @@ impl Default for Config {
}
impl Config {
fn cgroup_threshold(&self, total_mem: u64) -> u64 {
// We want our threshold to be met gracefully instead of letting postgres get OOM-killed
// (or if there's room, spilling to swap).
fn cgroup_threshold(&self, total_mem: u64, file_cache_disk_size: u64) -> u64 {
// If the file cache is in tmpfs, then it will count towards shmem usage of the cgroup,
// and thus be non-reclaimable, so we should allow for additional memory usage.
//
// If the file cache sits on disk, our desired stable system state is for it to be fully
// page cached (its contents should only be paged to/from disk in situations where we can't
// upscale fast enough). Page-cached memory is reclaimable, so we need to lower the
// threshold for non-reclaimable memory so we scale up *before* the kernel starts paging
// out the file cache.
let memory_remaining_for_cgroup = total_mem.saturating_sub(file_cache_disk_size);
// Even if we're not separately making room for the file cache (if it's in tmpfs), we still
// want our threshold to be met gracefully instead of letting postgres get OOM-killed.
// So we guarantee that there's at least `cgroup_min_overhead_fraction` of total memory
// remaining above the threshold.
(total_mem as f64 * (1.0 - self.cgroup_min_overhead_fraction)) as u64
let max_threshold = (total_mem as f64 * (1.0 - self.cgroup_min_overhead_fraction)) as u64;
memory_remaining_for_cgroup.min(max_threshold)
}
}
@@ -136,6 +149,11 @@ impl Runner {
let mem = get_total_system_memory();
let mut file_cache_disk_size = 0;
// We need to process file cache initialization before cgroup initialization, so that the memory
// allocated to the file cache is appropriately taken into account when we decide the cgroup's
// memory limits.
if let Some(connstr) = &args.pgconnstr {
info!("initializing file cache");
let config = FileCacheConfig::default();
@@ -166,6 +184,7 @@ impl Runner {
info!("file cache size actually got set to {actual_size}")
}
file_cache_disk_size = actual_size;
state.filecache = Some(file_cache);
}
@@ -188,7 +207,7 @@ impl Runner {
cgroup.watch(hist_tx).await
});
let threshold = state.config.cgroup_threshold(mem);
let threshold = state.config.cgroup_threshold(mem, file_cache_disk_size);
info!(threshold, "set initial cgroup threshold",);
state.cgroup = Some(CgroupState {
@@ -240,7 +259,9 @@ impl Runner {
return Ok((false, status.to_owned()));
}
let new_threshold = self.config.cgroup_threshold(usable_system_memory);
let new_threshold = self
.config
.cgroup_threshold(usable_system_memory, expected_file_cache_size);
let current = last_history.avg_non_reclaimable;
@@ -261,11 +282,13 @@ impl Runner {
// The downscaling has been approved. Downscale the file cache, then the cgroup.
let mut status = vec![];
let mut file_cache_disk_size = 0;
if let Some(file_cache) = &mut self.filecache {
let actual_usage = file_cache
.set_file_cache_size(expected_file_cache_size)
.await
.context("failed to set file cache size")?;
file_cache_disk_size = actual_usage;
let message = format!(
"set file cache size to {} MiB",
bytes_to_mebibytes(actual_usage),
@@ -275,7 +298,9 @@ impl Runner {
}
if let Some(cgroup) = &mut self.cgroup {
let new_threshold = self.config.cgroup_threshold(usable_system_memory);
let new_threshold = self
.config
.cgroup_threshold(usable_system_memory, file_cache_disk_size);
let message = format!(
"set cgroup memory threshold from {} MiB to {} MiB, of new total {} MiB",
@@ -304,6 +329,7 @@ impl Runner {
let new_mem = resources.mem;
let usable_system_memory = new_mem.saturating_sub(self.config.sys_buffer_bytes);
let mut file_cache_disk_size = 0;
if let Some(file_cache) = &mut self.filecache {
let expected_usage = file_cache.config.calculate_cache_size(usable_system_memory);
info!(
@@ -316,6 +342,7 @@ impl Runner {
.set_file_cache_size(expected_usage)
.await
.context("failed to set file cache size")?;
file_cache_disk_size = actual_usage;
if actual_usage != expected_usage {
warn!(
@@ -327,7 +354,9 @@ impl Runner {
}
if let Some(cgroup) = &mut self.cgroup {
let new_threshold = self.config.cgroup_threshold(usable_system_memory);
let new_threshold = self
.config
.cgroup_threshold(usable_system_memory, file_cache_disk_size);
info!(
"set cgroup memory threshold from {} MiB to {} MiB of new total {} MiB",

View File

@@ -704,8 +704,6 @@ async fn timeline_archival_config_handler(
let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?;
let timeline_id: TimelineId = parse_request_param(&request, "timeline_id")?;
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Warn);
let request_data: TimelineArchivalConfigRequest = json_request(&mut request).await?;
check_permission(&request, Some(tenant_shard_id.tenant_id))?;
let state = get_state(&request);
@@ -716,7 +714,7 @@ async fn timeline_archival_config_handler(
.get_attached_tenant_shard(tenant_shard_id)?;
tenant
.apply_timeline_archival_config(timeline_id, request_data.state, ctx)
.apply_timeline_archival_config(timeline_id, request_data.state)
.await?;
Ok::<_, ApiError>(())
}

View File

@@ -38,7 +38,6 @@ use std::future::Future;
use std::sync::Weak;
use std::time::SystemTime;
use storage_broker::BrokerClientChannel;
use timeline::offload::offload_timeline;
use tokio::io::BufReader;
use tokio::sync::watch;
use tokio::task::JoinSet;
@@ -288,13 +287,9 @@ pub struct Tenant {
/// During timeline creation, we first insert the TimelineId to the
/// creating map, then `timelines`, then remove it from the creating map.
/// **Lock order**: if acquiring both, acquire`timelines` before `timelines_creating`
/// **Lock order**: if acquiring both, acquire `timelines` before `timelines_creating`
timelines_creating: std::sync::Mutex<HashSet<TimelineId>>,
/// Possibly offloaded and archived timelines
/// **Lock order**: if acquiring both, acquire`timelines` before `timelines_offloaded`
timelines_offloaded: Mutex<HashMap<TimelineId, Arc<OffloadedTimeline>>>,
// This mutex prevents creation of new timelines during GC.
// Adding yet another mutex (in addition to `timelines`) is needed because holding
// `timelines` mutex during all GC iteration
@@ -489,65 +484,6 @@ impl WalRedoManager {
}
}
/// Reduced representation of a timeline that is tracked in `Tenant::timelines_offloaded`
/// instead of `Tenant::timelines`.
///
/// It is built from a loaded [`Timeline`] by `OffloadedTimeline::from_timeline`, which
/// copies the identifiers and shares the remote client and the deletion lock with the
/// original timeline object.
pub struct OffloadedTimeline {
/// Tenant shard this timeline belongs to.
pub tenant_shard_id: TenantShardId,
/// Id of the offloaded timeline itself.
pub timeline_id: TimelineId,
/// Id of the ancestor timeline, if there is one.
pub ancestor_timeline_id: Option<TimelineId>,
// TODO: once we persist offloaded state, make this lazily constructed
/// Remote storage client, shared (same `Arc`) with the timeline this was created from.
pub remote_client: Arc<RemoteTimelineClient>,
/// Prevent two tasks from deleting the timeline at the same time. If held, the
/// timeline is being deleted. If 'true', the timeline has already been deleted.
pub delete_progress: Arc<tokio::sync::Mutex<DeleteTimelineFlow>>,
}
impl OffloadedTimeline {
fn from_timeline(timeline: &Timeline) -> Self {
Self {
tenant_shard_id: timeline.tenant_shard_id,
timeline_id: timeline.timeline_id,
ancestor_timeline_id: timeline.get_ancestor_timeline_id(),
remote_client: timeline.remote_client.clone(),
delete_progress: timeline.delete_progress.clone(),
}
}
}
/// Handle to a timeline that is either loaded or offloaded.
///
/// Both variants hold an `Arc`, so cloning this enum only clones the pointer.
#[derive(Clone)]
pub enum TimelineOrOffloaded {
/// A loaded timeline.
Timeline(Arc<Timeline>),
/// An offloaded timeline, represented by its [`OffloadedTimeline`] record.
Offloaded(Arc<OffloadedTimeline>),
}
impl TimelineOrOffloaded {
    /// Tenant shard id of the wrapped timeline, whichever variant it is.
    pub fn tenant_shard_id(&self) -> TenantShardId {
        match self {
            Self::Timeline(loaded) => loaded.tenant_shard_id,
            Self::Offloaded(record) => record.tenant_shard_id,
        }
    }

    /// Timeline id of the wrapped timeline, whichever variant it is.
    pub fn timeline_id(&self) -> TimelineId {
        match self {
            Self::Timeline(loaded) => loaded.timeline_id,
            Self::Offloaded(record) => record.timeline_id,
        }
    }

    /// Borrows the deletion-progress lock of the wrapped timeline.
    pub fn delete_progress(&self) -> &Arc<tokio::sync::Mutex<DeleteTimelineFlow>> {
        match self {
            Self::Timeline(loaded) => &loaded.delete_progress,
            Self::Offloaded(record) => &record.delete_progress,
        }
    }

    /// Borrows the remote storage client of the wrapped timeline.
    pub fn remote_client(&self) -> &Arc<RemoteTimelineClient> {
        match self {
            Self::Timeline(loaded) => &loaded.remote_client,
            Self::Offloaded(record) => &record.remote_client,
        }
    }
}
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum GetTimelineError {
#[error("Timeline is shutting down")]
@@ -1470,192 +1406,52 @@ impl Tenant {
}
}
/// Validation for archiving `timeline_id`: fails if any loaded timeline has it as
/// ancestor and is not itself archived.
///
/// Returns `TimelineArchivalError::HasUnarchivedChildren` carrying the ids of all such
/// children, `Ok(())` if there are none.
fn check_to_be_archived_has_no_unarchived_children(
    timeline_id: TimelineId,
    timelines: &std::sync::MutexGuard<'_, HashMap<TimelineId, Arc<Timeline>>>,
) -> Result<(), TimelineArchivalError> {
    let unarchived_children: Vec<TimelineId> = timelines
        .iter()
        .filter(|(_, candidate)| {
            // Only direct children count, and only those not reported as archived.
            candidate.get_ancestor_timeline_id() == Some(timeline_id)
                && candidate.is_archived() != Some(true)
        })
        .map(|(child_id, _)| *child_id)
        .collect();

    if unarchived_children.is_empty() {
        Ok(())
    } else {
        Err(TimelineArchivalError::HasUnarchivedChildren(
            unarchived_children,
        ))
    }
}
/// Validation for unarchiving a timeline whose ancestor is `ancestor_timeline_id`.
///
/// The ancestor counts as archived when it is a loaded timeline reporting
/// `is_archived() == Some(true)`, or when it is present in `offloaded_timelines`.
///
/// Returns `HasArchivedParent` in that case. If the ancestor is found in neither map,
/// an error is logged, debug builds panic, and release builds return `NotFound`.
fn check_ancestor_of_to_be_unarchived_is_not_archived(
    ancestor_timeline_id: TimelineId,
    timelines: &std::sync::MutexGuard<'_, HashMap<TimelineId, Arc<Timeline>>>,
    offloaded_timelines: &std::sync::MutexGuard<
        '_,
        HashMap<TimelineId, Arc<OffloadedTimeline>>,
    >,
) -> Result<(), TimelineArchivalError> {
    let ancestor_is_archived = match timelines.get(&ancestor_timeline_id) {
        Some(loaded_ancestor) => loaded_ancestor.is_archived() == Some(true),
        None if offloaded_timelines.contains_key(&ancestor_timeline_id) => true,
        None => {
            error!("ancestor timeline {ancestor_timeline_id} not found");
            if cfg!(debug_assertions) {
                panic!("ancestor timeline {ancestor_timeline_id} not found");
            }
            return Err(TimelineArchivalError::NotFound);
        }
    };

    if !ancestor_is_archived {
        return Ok(());
    }
    Err(TimelineArchivalError::HasArchivedParent(
        ancestor_timeline_id,
    ))
}
/// Validation for unarchiving a loaded `timeline`: fails with `HasArchivedParent` if
/// its ancestor timeline exists and reports `is_archived() == Some(true)`.
fn check_to_be_unarchived_timeline_has_no_archived_parent(
    timeline: &Arc<Timeline>,
) -> Result<(), TimelineArchivalError> {
    match timeline.ancestor_timeline() {
        Some(parent) if parent.is_archived() == Some(true) => Err(
            TimelineArchivalError::HasArchivedParent(parent.timeline_id),
        ),
        // No ancestor, or the ancestor is not reported as archived.
        _ => Ok(()),
    }
}
/// Loads the specified (offloaded) timeline from remote storage and attaches it as a
/// loaded timeline.
///
/// Steps, in order:
/// 1. download the timeline's index part via `load_timeline_metadata`,
/// 2. reject timelines whose index part is missing or marked deleted,
/// 3. call `load_remote_timeline` with the downloaded index part,
/// 4. look the timeline up in `self.timelines` and drop its entry from
///    `self.timelines_offloaded`.
///
/// # Errors
/// * `TimelineArchivalError::NotFound` if the index part does not exist remotely or
///   says the timeline is deleted.
/// * `TimelineArchivalError::Other` for any other index part download error, and when
///   the timeline is not present in `self.timelines` after loading.
/// * A `load_remote_timeline` failure is propagated with `?` after context is attached.
async fn unoffload_timeline(
self: &Arc<Self>,
timeline_id: TimelineId,
ctx: RequestContext,
) -> Result<Arc<Timeline>, TimelineArchivalError> {
let cancel = self.cancel.clone();
let timeline_preload = self
.load_timeline_metadata(timeline_id, self.remote_storage.clone(), cancel)
.await;
// First unwrap the download result ...
let index_part = match timeline_preload.index_part {
Ok(index_part) => {
debug!("remote index part exists for timeline {timeline_id}");
index_part
}
Err(DownloadError::NotFound) => {
error!(%timeline_id, "index_part not found on remote");
return Err(TimelineArchivalError::NotFound);
}
Err(e) => {
// Some (possibly ephemeral) error happened during index_part download.
warn!(%timeline_id, "Failed to load index_part from remote storage, failed creation? ({e})");
return Err(TimelineArchivalError::Other(
anyhow::Error::new(e).context("downloading index_part from remote storage"),
));
}
};
// ... then unwrap the deleted/not-deleted marker; a deleted timeline is reported as not found.
let index_part = match index_part {
MaybeDeletedIndexPart::IndexPart(index_part) => index_part,
MaybeDeletedIndexPart::Deleted(_index_part) => {
info!("timeline is deleted according to index_part.json");
return Err(TimelineArchivalError::NotFound);
}
};
// Clone the metadata before `index_part` is moved into `load_remote_timeline`.
let remote_metadata = index_part.metadata.clone();
let timeline_resources = self.build_timeline_resources(timeline_id);
self.load_remote_timeline(
timeline_id,
index_part,
remote_metadata,
timeline_resources,
&ctx,
)
.await
.with_context(|| {
format!(
"failed to load remote timeline {} for tenant {}",
timeline_id, self.tenant_shard_id
)
})?;
// Lock order: `timelines` is taken before `timelines_offloaded`. Both guards are
// std mutex guards acquired only after the last `.await` above.
let timelines = self.timelines.lock().unwrap();
if let Some(timeline) = timelines.get(&timeline_id) {
let mut offloaded_timelines = self.timelines_offloaded.lock().unwrap();
// A missing offloaded entry is only logged; the loaded timeline is still returned.
if offloaded_timelines.remove(&timeline_id).is_none() {
warn!("timeline already removed from offloaded timelines");
}
Ok(Arc::clone(timeline))
} else {
warn!("timeline not available directly after attach");
Err(TimelineArchivalError::Other(anyhow::anyhow!(
"timeline not available directly after attach"
)))
}
}
pub(crate) async fn apply_timeline_archival_config(
self: &Arc<Self>,
&self,
timeline_id: TimelineId,
new_state: TimelineArchivalState,
ctx: RequestContext,
state: TimelineArchivalState,
) -> Result<(), TimelineArchivalError> {
info!("setting timeline archival config");
// First part: figure out what is needed to do, and do validation
let timeline_or_unarchive_offloaded = 'outer: {
let timeline = {
let timelines = self.timelines.lock().unwrap();
let Some(timeline) = timelines.get(&timeline_id) else {
let offloaded_timelines = self.timelines_offloaded.lock().unwrap();
let Some(offloaded) = offloaded_timelines.get(&timeline_id) else {
return Err(TimelineArchivalError::NotFound);
};
if new_state == TimelineArchivalState::Archived {
// It's offloaded already, so nothing to do
return Ok(());
}
if let Some(ancestor_timeline_id) = offloaded.ancestor_timeline_id {
Self::check_ancestor_of_to_be_unarchived_is_not_archived(
ancestor_timeline_id,
&timelines,
&offloaded_timelines,
)?;
}
break 'outer None;
return Err(TimelineArchivalError::NotFound);
};
// Do some validation. We release the timelines lock below, so there is potential
// for race conditions: these checks are more present to prevent misunderstandings of
// the API's capabilities, instead of serving as the sole way to defend their invariants.
match new_state {
TimelineArchivalState::Unarchived => {
Self::check_to_be_unarchived_timeline_has_no_archived_parent(timeline)?
}
TimelineArchivalState::Archived => {
Self::check_to_be_archived_has_no_unarchived_children(timeline_id, &timelines)?
if state == TimelineArchivalState::Unarchived {
if let Some(ancestor_timeline) = timeline.ancestor_timeline() {
if ancestor_timeline.is_archived() == Some(true) {
return Err(TimelineArchivalError::HasArchivedParent(
ancestor_timeline.timeline_id,
));
}
}
}
Some(Arc::clone(timeline))
// Ensure that there are no non-archived child timelines
let children: Vec<TimelineId> = timelines
.iter()
.filter_map(|(id, entry)| {
if entry.get_ancestor_timeline_id() != Some(timeline_id) {
return None;
}
if entry.is_archived() == Some(true) {
return None;
}
Some(*id)
})
.collect();
if !children.is_empty() && state == TimelineArchivalState::Archived {
return Err(TimelineArchivalError::HasUnarchivedChildren(children));
}
Arc::clone(timeline)
};
// Second part: unarchive timeline (if needed)
let timeline = if let Some(timeline) = timeline_or_unarchive_offloaded {
timeline
} else {
// Turn offloaded timeline into a non-offloaded one
self.unoffload_timeline(timeline_id, ctx).await?
};
// Third part: upload new timeline archival state and block until it is present in S3
let upload_needed = timeline
.remote_client
.schedule_index_upload_for_timeline_archival_state(new_state)?;
.schedule_index_upload_for_timeline_archival_state(state)?;
if upload_needed {
info!("Uploading new state");
@@ -2088,7 +1884,7 @@ impl Tenant {
///
/// Returns whether we have pending compaction task.
async fn compaction_iteration(
self: &Arc<Self>,
&self,
cancel: &CancellationToken,
ctx: &RequestContext,
) -> Result<bool, timeline::CompactionError> {
@@ -2109,28 +1905,21 @@ impl Tenant {
// while holding the lock. Then drop the lock and actually perform the
// compactions. We don't want to block everything else while the
// compaction runs.
let timelines_to_compact_or_offload;
{
let timelines_to_compact = {
let timelines = self.timelines.lock().unwrap();
timelines_to_compact_or_offload = timelines
let timelines_to_compact = timelines
.iter()
.filter_map(|(timeline_id, timeline)| {
let (is_active, can_offload) = (timeline.is_active(), timeline.can_offload());
let has_no_unoffloaded_children = {
!timelines
.iter()
.any(|(_id, tl)| tl.get_ancestor_timeline_id() == Some(*timeline_id))
};
let can_offload = can_offload && has_no_unoffloaded_children;
if (is_active, can_offload) == (false, false) {
None
if timeline.is_active() {
Some((*timeline_id, timeline.clone()))
} else {
Some((*timeline_id, timeline.clone(), (is_active, can_offload)))
None
}
})
.collect::<Vec<_>>();
drop(timelines);
}
timelines_to_compact
};
// Before doing any I/O work, check our circuit breaker
if self.compaction_circuit_breaker.lock().unwrap().is_broken() {
@@ -2140,34 +1929,20 @@ impl Tenant {
let mut has_pending_task = false;
for (timeline_id, timeline, (can_compact, can_offload)) in &timelines_to_compact_or_offload
{
let pending_task_left = if *can_compact {
Some(
timeline
.compact(cancel, EnumSet::empty(), ctx)
.instrument(info_span!("compact_timeline", %timeline_id))
.await
.inspect_err(|e| match e {
timeline::CompactionError::ShuttingDown => (),
timeline::CompactionError::Other(e) => {
self.compaction_circuit_breaker
.lock()
.unwrap()
.fail(&CIRCUIT_BREAKERS_BROKEN, e);
}
})?,
)
} else {
None
};
has_pending_task |= pending_task_left.unwrap_or(false);
if pending_task_left == Some(false) && *can_offload {
offload_timeline(self, timeline)
.instrument(info_span!("offload_timeline", %timeline_id))
.await
.map_err(timeline::CompactionError::Other)?;
}
for (timeline_id, timeline) in &timelines_to_compact {
has_pending_task |= timeline
.compact(cancel, EnumSet::empty(), ctx)
.instrument(info_span!("compact_timeline", %timeline_id))
.await
.inspect_err(|e| match e {
timeline::CompactionError::ShuttingDown => (),
timeline::CompactionError::Other(e) => {
self.compaction_circuit_breaker
.lock()
.unwrap()
.fail(&CIRCUIT_BREAKERS_BROKEN, e);
}
})?;
}
self.compaction_circuit_breaker
@@ -3077,7 +2852,6 @@ impl Tenant {
constructed_at: Instant::now(),
timelines: Mutex::new(HashMap::new()),
timelines_creating: Mutex::new(HashSet::new()),
timelines_offloaded: Mutex::new(HashMap::new()),
gc_cs: tokio::sync::Mutex::new(()),
walredo_mgr,
remote_storage,

View File

@@ -141,14 +141,14 @@ impl GcBlock {
Ok(())
}
pub(crate) fn before_delete(&self, timeline_id: &super::TimelineId) {
pub(crate) fn before_delete(&self, timeline: &super::Timeline) {
let unblocked = {
let mut g = self.reasons.lock().unwrap();
if g.is_empty() {
return;
}
g.remove(timeline_id);
g.remove(&timeline.timeline_id);
BlockingReasons::clean_and_summarize(g).is_none()
};

View File

@@ -950,7 +950,6 @@ impl<'a> TenantDownloader<'a> {
let cancel = &self.secondary_state.cancel;
let opts = DownloadOpts {
etag: prev_etag.cloned(),
..Default::default()
};
backoff::retry(

View File

@@ -7,7 +7,6 @@ pub(crate) mod handle;
mod init;
pub mod layer_manager;
pub(crate) mod logical_size;
pub mod offload;
pub mod span;
pub mod uninit;
mod walreceiver;
@@ -1557,17 +1556,6 @@ impl Timeline {
}
}
/// Reports whether this timeline's own state permits offloading, i.e. whether the
/// remote client reports it as archived (`Some(true)`).
///
/// This is necessary but not sufficient for offloading the timeline, as it might have
/// child timelines that are not offloaded yet.
pub(crate) fn can_offload(&self) -> bool {
    matches!(self.remote_client.is_archived(), Some(true))
}
/// Outermost timeline compaction operation; downloads needed layers. Returns whether we have pending
/// compaction tasks.
pub(crate) async fn compact(
@@ -1830,6 +1818,7 @@ impl Timeline {
self.current_state() == TimelineState::Active
}
/// Returns the archival state exactly as reported by `self.remote_client.is_archived()`.
///
/// NOTE(review): the meaning of `None` is not visible here — presumably "not known";
/// confirm against `RemoteTimelineClient::is_archived`.
#[allow(unused)]
pub(crate) fn is_archived(&self) -> Option<bool> {
self.remote_client.is_archived()
}

View File

@@ -15,7 +15,7 @@ use crate::{
tenant::{
metadata::TimelineMetadata,
remote_timeline_client::{PersistIndexPartWithDeletedFlagError, RemoteTimelineClient},
CreateTimelineCause, DeleteTimelineError, Tenant, TimelineOrOffloaded,
CreateTimelineCause, DeleteTimelineError, Tenant,
},
};
@@ -24,14 +24,12 @@ use super::{Timeline, TimelineResources};
/// Mark timeline as deleted in S3 so we won't pick it up next time
/// during attach or pageserver restart.
/// See comment in persist_index_part_with_deleted_flag.
async fn set_deleted_in_remote_index(
timeline: &TimelineOrOffloaded,
) -> Result<(), DeleteTimelineError> {
let res = timeline
.remote_client()
async fn set_deleted_in_remote_index(timeline: &Timeline) -> Result<(), DeleteTimelineError> {
match timeline
.remote_client
.persist_index_part_with_deleted_flag()
.await;
match res {
.await
{
// If we (now, or already) marked it successfully as deleted, we can proceed
Ok(()) | Err(PersistIndexPartWithDeletedFlagError::AlreadyDeleted(_)) => (),
// Bail out otherwise
@@ -129,9 +127,9 @@ pub(super) async fn delete_local_timeline_directory(
}
/// Removes remote layers and an index file after them.
async fn delete_remote_layers_and_index(timeline: &TimelineOrOffloaded) -> anyhow::Result<()> {
async fn delete_remote_layers_and_index(timeline: &Timeline) -> anyhow::Result<()> {
timeline
.remote_client()
.remote_client
.delete_all()
.await
.context("delete_all")
@@ -139,41 +137,27 @@ async fn delete_remote_layers_and_index(timeline: &TimelineOrOffloaded) -> anyho
/// It is important that this gets called when DeletionGuard is being held.
/// For more context see comments in [`DeleteTimelineFlow::prepare`]
async fn remove_maybe_offloaded_timeline_from_tenant(
async fn remove_timeline_from_tenant(
tenant: &Tenant,
timeline: &TimelineOrOffloaded,
timeline: &Timeline,
_: &DeletionGuard, // using it as a witness
) -> anyhow::Result<()> {
// Remove the timeline from the map.
// This observes the locking order between timelines and timelines_offloaded
let mut timelines = tenant.timelines.lock().unwrap();
let mut timelines_offloaded = tenant.timelines_offloaded.lock().unwrap();
let offloaded_children_exist = timelines_offloaded
.iter()
.any(|(_, entry)| entry.ancestor_timeline_id == Some(timeline.timeline_id()));
let children_exist = timelines
.iter()
.any(|(_, entry)| entry.get_ancestor_timeline_id() == Some(timeline.timeline_id()));
// XXX this can happen because of race conditions with branch creation.
// We already deleted the remote layer files, so it's probably best to panic.
if children_exist || offloaded_children_exist {
.any(|(_, entry)| entry.get_ancestor_timeline_id() == Some(timeline.timeline_id));
// XXX this can happen because `branch_timeline` doesn't check `TimelineState::Stopping`.
// We already deleted the layer files, so it's probably best to panic.
// (Ideally, above remove_dir_all is atomic so we don't see this timeline after a restart)
if children_exist {
panic!("Timeline grew children while we removed layer files");
}
match timeline {
TimelineOrOffloaded::Timeline(timeline) => {
timelines.remove(&timeline.timeline_id).expect(
"timeline that we were deleting was concurrently removed from 'timelines' map",
);
}
TimelineOrOffloaded::Offloaded(timeline) => {
timelines_offloaded
.remove(&timeline.timeline_id)
.expect("timeline that we were deleting was concurrently removed from 'timelines_offloaded' map");
}
}
timelines
.remove(&timeline.timeline_id)
.expect("timeline that we were deleting was concurrently removed from 'timelines' map");
drop(timelines_offloaded);
drop(timelines);
Ok(())
@@ -223,11 +207,9 @@ impl DeleteTimelineFlow {
guard.mark_in_progress()?;
// Now that the Timeline is in Stopping state, request all the related tasks to shut down.
if let TimelineOrOffloaded::Timeline(timeline) = &timeline {
timeline.shutdown(super::ShutdownMode::Hard).await;
}
timeline.shutdown(super::ShutdownMode::Hard).await;
tenant.gc_block.before_delete(&timeline.timeline_id());
tenant.gc_block.before_delete(&timeline);
fail::fail_point!("timeline-delete-before-index-deleted-at", |_| {
Err(anyhow::anyhow!(
@@ -303,16 +285,15 @@ impl DeleteTimelineFlow {
guard.mark_in_progress()?;
let timeline = TimelineOrOffloaded::Timeline(timeline);
Self::schedule_background(guard, tenant.conf, tenant, timeline);
Ok(())
}
pub(super) fn prepare(
fn prepare(
tenant: &Tenant,
timeline_id: TimelineId,
) -> Result<(TimelineOrOffloaded, DeletionGuard), DeleteTimelineError> {
) -> Result<(Arc<Timeline>, DeletionGuard), DeleteTimelineError> {
// Note the interaction between this guard and deletion guard.
// Here we attempt to lock deletion guard when we're holding a lock on timelines.
// This is important because when you take into account `remove_timeline_from_tenant`
@@ -326,14 +307,8 @@ impl DeleteTimelineFlow {
let timelines = tenant.timelines.lock().unwrap();
let timeline = match timelines.get(&timeline_id) {
Some(t) => TimelineOrOffloaded::Timeline(Arc::clone(t)),
None => {
let offloaded_timelines = tenant.timelines_offloaded.lock().unwrap();
match offloaded_timelines.get(&timeline_id) {
Some(t) => TimelineOrOffloaded::Offloaded(Arc::clone(t)),
None => return Err(DeleteTimelineError::NotFound),
}
}
Some(t) => t,
None => return Err(DeleteTimelineError::NotFound),
};
// Ensure that there are no child timelines **attached to that pageserver**,
@@ -359,32 +334,30 @@ impl DeleteTimelineFlow {
// to remove the timeline from it.
// Always if you have two locks that are taken in different order this can result in a deadlock.
let delete_progress = Arc::clone(timeline.delete_progress());
let delete_progress = Arc::clone(&timeline.delete_progress);
let delete_lock_guard = match delete_progress.try_lock_owned() {
Ok(guard) => DeletionGuard(guard),
Err(_) => {
// Unfortunately if lock fails arc is consumed.
return Err(DeleteTimelineError::AlreadyInProgress(Arc::clone(
timeline.delete_progress(),
&timeline.delete_progress,
)));
}
};
if let TimelineOrOffloaded::Timeline(timeline) = &timeline {
timeline.set_state(TimelineState::Stopping);
}
timeline.set_state(TimelineState::Stopping);
Ok((timeline, delete_lock_guard))
Ok((Arc::clone(timeline), delete_lock_guard))
}
fn schedule_background(
guard: DeletionGuard,
conf: &'static PageServerConf,
tenant: Arc<Tenant>,
timeline: TimelineOrOffloaded,
timeline: Arc<Timeline>,
) {
let tenant_shard_id = timeline.tenant_shard_id();
let timeline_id = timeline.timeline_id();
let tenant_shard_id = timeline.tenant_shard_id;
let timeline_id = timeline.timeline_id;
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
@@ -395,9 +368,7 @@ impl DeleteTimelineFlow {
async move {
if let Err(err) = Self::background(guard, conf, &tenant, &timeline).await {
error!("Error: {err:#}");
if let TimelineOrOffloaded::Timeline(timeline) = timeline {
timeline.set_broken(format!("{err:#}"))
}
timeline.set_broken(format!("{err:#}"))
};
Ok(())
}
@@ -409,19 +380,15 @@ impl DeleteTimelineFlow {
mut guard: DeletionGuard,
conf: &PageServerConf,
tenant: &Tenant,
timeline: &TimelineOrOffloaded,
timeline: &Timeline,
) -> Result<(), DeleteTimelineError> {
// Offloaded timelines have no local state
// TODO: once we persist offloaded information, delete the timeline from there, too
if let TimelineOrOffloaded::Timeline(timeline) = timeline {
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await?;
}
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await?;
delete_remote_layers_and_index(timeline).await?;
pausable_failpoint!("in_progress_delete");
remove_maybe_offloaded_timeline_from_tenant(tenant, timeline, &guard).await?;
remove_timeline_from_tenant(tenant, timeline, &guard).await?;
*guard = Self::Finished;
@@ -433,7 +400,7 @@ impl DeleteTimelineFlow {
}
}
pub(super) struct DeletionGuard(OwnedMutexGuard<DeleteTimelineFlow>);
struct DeletionGuard(OwnedMutexGuard<DeleteTimelineFlow>);
impl Deref for DeletionGuard {
type Target = DeleteTimelineFlow;

View File

@@ -1,69 +0,0 @@
use std::sync::Arc;
use crate::tenant::{OffloadedTimeline, Tenant, TimelineOrOffloaded};
use super::{
delete::{delete_local_timeline_directory, DeleteTimelineFlow, DeletionGuard},
Timeline,
};
/// Offloads an archived timeline: its local directory is deleted, it is removed from
/// the tenant's `timelines` map, and an [`OffloadedTimeline`] built from it is inserted
/// into `timelines_offloaded` under the same timeline id.
///
/// Reuses [`DeleteTimelineFlow::prepare`] to obtain the timeline together with its
/// [`DeletionGuard`]. If `prepare` reports that the timeline is already offloaded, an
/// error is logged and `Ok(())` is returned without touching anything.
pub(crate) async fn offload_timeline(
    tenant: &Tenant,
    timeline: &Arc<Timeline>,
) -> anyhow::Result<()> {
    tracing::info!("offloading archived timeline");

    let (prepared, guard) = DeleteTimelineFlow::prepare(tenant, timeline.timeline_id)?;
    let timeline = match prepared {
        TimelineOrOffloaded::Timeline(timeline) => timeline,
        _ => {
            tracing::error!("timeline already offloaded, but given timeline object");
            return Ok(());
        }
    };

    // TODO extend guard mechanism above with method
    // to make deletions possible while offloading is in progress

    // TODO mark timeline as offloaded in S3

    delete_local_timeline_directory(&tenant.conf, tenant.tenant_shard_id, &timeline).await?;

    // The guard is passed as a witness that we are still holding it.
    remove_timeline_from_tenant(tenant, &timeline, &guard).await?;

    {
        let mut offloaded = tenant.timelines_offloaded.lock().unwrap();
        let entry = Arc::new(OffloadedTimeline::from_timeline(&timeline));
        offloaded.insert(timeline.timeline_id, entry);
    }

    Ok(())
}
/// Removes `timeline` from the tenant's `timelines` map.
///
/// It is important that this gets called when DeletionGuard is being held.
/// For more context see comments in [`DeleteTimelineFlow::prepare`]
///
/// # Panics
///
/// Panics if another timeline in the map names this one as its ancestor, or if the
/// timeline is no longer present in the map.
async fn remove_timeline_from_tenant(
    tenant: &Tenant,
    timeline: &Timeline,
    _: &DeletionGuard, // using it as a witness
) -> anyhow::Result<()> {
    {
        // The map lock is held for the whole check-and-remove sequence.
        let mut timelines = tenant.timelines.lock().unwrap();

        let has_children = timelines
            .values()
            .any(|other| other.get_ancestor_timeline_id() == Some(timeline.timeline_id));

        // XXX this can happen because `branch_timeline` doesn't check `TimelineState::Stopping`.
        // We already deleted the layer files, so it's probably best to panic.
        // (Ideally, above remove_dir_all is atomic so we don't see this timeline after a restart)
        assert!(
            !has_children,
            "Timeline grew children while we removed layer files"
        );

        timelines
            .remove(&timeline.timeline_id)
            .expect("timeline that we were deleting was concurrently removed from 'timelines' map");
    }

    Ok(())
}

View File

@@ -44,6 +44,7 @@ pub(crate) use api::IoMode;
pub(crate) use io_engine::IoEngineKind;
pub(crate) use metadata::Metadata;
pub(crate) use open_options::*;
pub(crate) mod dio;
pub(crate) mod owned_buffers_io {
//! Abstractions for IO with owned buffers.
@@ -55,6 +56,7 @@ pub(crate) mod owned_buffers_io {
//! but for the time being we're proving out the primitives in the neon.git repo
//! for faster iteration.
pub(crate) mod io_buf_aligned;
pub(crate) mod io_buf_ext;
pub(crate) mod slice;
pub(crate) mod write;
@@ -64,22 +66,39 @@ pub(crate) mod owned_buffers_io {
}
#[derive(Debug)]
pub struct VirtualFile {
inner: VirtualFileInner,
_mode: IoMode,
pub enum VirtualFile {
Buffered(VirtualFileInner),
Direct(VirtualFileInner),
}
impl VirtualFile {
fn inner(&self) -> &VirtualFileInner {
match self {
Self::Buffered(file) => file,
Self::Direct(file) => file,
}
}
fn inner_mut(&mut self) -> &mut VirtualFileInner {
match self {
Self::Buffered(file) => file,
Self::Direct(file) => file,
}
}
fn into_inner(self) -> VirtualFileInner {
match self {
Self::Buffered(file) => file,
Self::Direct(file) => file,
}
}
/// Open a file in read-only mode. Like File::open.
pub async fn open<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
let inner = VirtualFileInner::open(path, ctx).await?;
Ok(VirtualFile {
inner,
_mode: IoMode::Buffered,
})
let file = VirtualFileInner::open(path, ctx).await?;
Ok(Self::Buffered(file))
}
/// Open a file in read-only mode. Like File::open.
@@ -96,11 +115,8 @@ impl VirtualFile {
path: P,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
let inner = VirtualFileInner::create(path, ctx).await?;
Ok(VirtualFile {
inner,
_mode: IoMode::Buffered,
})
let file = VirtualFileInner::create(path, ctx).await?;
Ok(Self::Buffered(file))
}
pub async fn create_v2<P: AsRef<Utf8Path>>(
@@ -120,45 +136,36 @@ impl VirtualFile {
open_options: &OpenOptions,
ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */
) -> Result<Self, std::io::Error> {
let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
Ok(VirtualFile {
inner,
_mode: IoMode::Buffered,
})
let file = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
Ok(Self::Buffered(file))
}
pub async fn open_with_options_v2<P: AsRef<Utf8Path>>(
path: P,
open_options: &OpenOptions,
open_options: &mut OpenOptions, // Uses `&mut` here to add `O_DIRECT`.
ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */
) -> Result<Self, std::io::Error> {
let file = match get_io_mode() {
IoMode::Buffered => {
let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
VirtualFile {
inner,
_mode: IoMode::Buffered,
}
let file = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
Self::Buffered(file)
}
#[cfg(target_os = "linux")]
IoMode::Direct => {
let inner = VirtualFileInner::open_with_options(
let file = VirtualFileInner::open_with_options(
path,
open_options.clone().custom_flags(nix::libc::O_DIRECT),
open_options.custom_flags(nix::libc::O_DIRECT),
ctx,
)
.await?;
VirtualFile {
inner,
_mode: IoMode::Direct,
}
Self::Direct(file)
}
};
Ok(file)
}
pub fn path(&self) -> &Utf8Path {
self.inner.path.as_path()
self.inner().path.as_path()
}
pub async fn crashsafe_overwrite<B: BoundedBuf<Buf = Buf> + Send, Buf: IoBuf + Send>(
@@ -170,23 +177,23 @@ impl VirtualFile {
}
pub async fn sync_all(&self) -> Result<(), Error> {
self.inner.sync_all().await
self.inner().sync_all().await
}
pub async fn sync_data(&self) -> Result<(), Error> {
self.inner.sync_data().await
self.inner().sync_data().await
}
pub async fn metadata(&self) -> Result<Metadata, Error> {
self.inner.metadata().await
self.inner().metadata().await
}
pub fn remove(self) {
self.inner.remove();
self.into_inner().remove();
}
pub async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Error> {
self.inner.seek(pos).await
self.inner_mut().seek(pos).await
}
pub async fn read_exact_at<Buf>(
@@ -198,7 +205,7 @@ impl VirtualFile {
where
Buf: IoBufMut + Send,
{
self.inner.read_exact_at(slice, offset, ctx).await
self.inner().read_exact_at(slice, offset, ctx).await
}
pub async fn read_exact_at_page(
@@ -207,7 +214,7 @@ impl VirtualFile {
offset: u64,
ctx: &RequestContext,
) -> Result<PageWriteGuard<'static>, Error> {
self.inner.read_exact_at_page(page, offset, ctx).await
self.inner().read_exact_at_page(page, offset, ctx).await
}
pub async fn write_all_at<Buf: IoBuf + Send>(
@@ -216,7 +223,7 @@ impl VirtualFile {
offset: u64,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<(), Error>) {
self.inner.write_all_at(buf, offset, ctx).await
self.inner().write_all_at(buf, offset, ctx).await
}
pub async fn write_all<Buf: IoBuf + Send>(
@@ -224,7 +231,7 @@ impl VirtualFile {
buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<usize, Error>) {
self.inner.write_all(buf, ctx).await
self.inner_mut().write_all(buf, ctx).await
}
}
@@ -1206,11 +1213,11 @@ impl VirtualFile {
blknum: u32,
ctx: &RequestContext,
) -> Result<crate::tenant::block_io::BlockLease<'_>, std::io::Error> {
self.inner.read_blk(blknum, ctx).await
self.inner().read_blk(blknum, ctx).await
}
async fn read_to_end(&mut self, buf: &mut Vec<u8>, ctx: &RequestContext) -> Result<(), Error> {
self.inner.read_to_end(buf, ctx).await
self.inner_mut().read_to_end(buf, ctx).await
}
}
@@ -1357,6 +1364,8 @@ pub(crate) const fn get_io_buffer_alignment() -> usize {
DEFAULT_IO_BUFFER_ALIGNMENT
}
pub(crate) type IoBufferMut = dio::AlignedBufferMut<{ get_io_buffer_alignment() }>;
static IO_MODE: AtomicU8 = AtomicU8::new(IoMode::preferred() as u8);
pub(crate) fn set_io_mode(mode: IoMode) {

View File

@@ -0,0 +1,405 @@
#![allow(unused)]
use core::slice;
use std::{
alloc::{self, Layout},
cmp,
mem::{ManuallyDrop, MaybeUninit},
ops::{Deref, DerefMut},
ptr::{addr_of_mut, NonNull},
};
use bytes::buf::UninitSlice;
/// Owning raw pointer to the start of the buffer's storage.
///
/// While the owning buffer has a capacity of zero, this is a dangling (but non-null and
/// `ALIGN`-aligned) address that was never returned by the allocator and must never be
/// passed to `realloc`/`dealloc`.
struct IoBufferPtr(*mut u8);

// SAFETY: We guarantee no one besides `IoBufferPtr` itself has the raw pointer.
unsafe impl Send for IoBufferPtr {}

/// An aligned buffer type used for I/O.
///
/// The start of the storage is aligned to `ALIGN` bytes. The first `len` bytes are
/// initialized; the storage holds `capacity` bytes in total.
pub struct AlignedBufferMut<const ALIGN: usize> {
    /// Start of the storage; dangling while `capacity == 0`.
    ptr: IoBufferPtr,
    /// Size in bytes of the allocation behind `ptr`.
    capacity: usize,
    /// Number of initialized bytes at the start of the storage.
    len: usize,
}

impl<const ALIGN: usize> AlignedBufferMut<ALIGN> {
    /// Obtains `ALIGN`-aligned storage for `capacity` bytes, zero-filled if `zeroed` is set.
    ///
    /// A zero `capacity` never reaches the global allocator, because allocating with a
    /// zero-sized layout is undefined behaviour. A dangling aligned pointer is returned instead.
    ///
    /// # Panics
    ///
    /// Panics with "Invalid layout" if `(capacity, ALIGN)` is not a valid [`Layout`].
    fn allocate(capacity: usize, zeroed: bool) -> IoBufferPtr {
        let layout = Layout::from_size_align(capacity, ALIGN).expect("Invalid layout");
        if layout.size() == 0 {
            // `Layout` has validated that the alignment is a non-zero power of two, so the
            // alignment itself is a non-null, suitably aligned address.
            return IoBufferPtr(layout.align() as *mut u8);
        }
        // SAFETY: `layout` has a non-zero size. The memory is manually freed with the same layout.
        let ptr = unsafe {
            if zeroed {
                alloc::alloc_zeroed(layout)
            } else {
                alloc::alloc(layout)
            }
        };
        if ptr.is_null() {
            alloc::handle_alloc_error(layout);
        }
        IoBufferPtr(ptr)
    }

    /// Constructs a new, empty `IoBufferMut` with the specified capacity and alignment.
    ///
    /// The buffer holds at most `capacity` bytes until [`reserve`](Self::reserve) is called
    /// explicitly. A `capacity` of zero is allowed and performs no allocation.
    ///
    /// # Panics
    ///
    /// Panics if the new capacity exceeds `isize::MAX` _bytes_, or if the following alignment requirement is not met:
    /// * `align` must not be zero,
    ///
    /// * `align` must be a power of two,
    ///
    /// * `capacity`, when rounded up to the nearest multiple of `align`,
    ///   must not overflow isize (i.e., the rounded value must be
    ///   less than or equal to `isize::MAX`).
    pub fn with_capacity(capacity: usize) -> Self {
        AlignedBufferMut {
            ptr: Self::allocate(capacity, false),
            capacity,
            len: 0,
        }
    }

    /// Constructs a new `IoBufferMut` with the specified capacity and alignment, filled with zeros.
    ///
    /// The returned buffer has `len() == capacity()`.
    ///
    /// # Panics
    ///
    /// Same conditions as [`with_capacity`](Self::with_capacity).
    pub fn with_capacity_zeroed(capacity: usize) -> Self {
        AlignedBufferMut {
            // The allocator hands back zeroed memory, so every byte is already initialized.
            ptr: Self::allocate(capacity, true),
            capacity,
            len: capacity,
        }
    }

    /// Returns the total number of bytes the buffer can hold.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Returns the alignment of the buffer.
    #[inline]
    pub const fn align(&self) -> usize {
        ALIGN
    }

    /// Returns the number of bytes in the buffer, also referred to as its 'length'.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Force the length of the buffer to `new_len`.
    ///
    /// # Safety
    ///
    /// `new_len` must not exceed the capacity, and the first `new_len` bytes must be initialized.
    #[inline]
    unsafe fn set_len(&mut self, new_len: usize) {
        debug_assert!(new_len <= self.capacity());
        self.len = new_len;
    }

    /// Returns a raw pointer to the start of the storage.
    #[inline]
    fn as_ptr(&self) -> *const u8 {
        self.ptr.0
    }

    /// Returns a raw mutable pointer to the start of the storage.
    #[inline]
    fn as_mut_ptr(&mut self) -> *mut u8 {
        self.ptr.0
    }

    /// Extracts a slice containing the entire buffer.
    ///
    /// Equivalent to `&s[..]`.
    #[inline]
    fn as_slice(&self) -> &[u8] {
        // SAFETY: The pointer is non-null and aligned, and `len` bytes are initialized.
        unsafe { slice::from_raw_parts(self.as_ptr(), self.len) }
    }

    /// Extracts a mutable slice of the entire buffer.
    ///
    /// Equivalent to `&mut s[..]`.
    fn as_mut_slice(&mut self) -> &mut [u8] {
        // SAFETY: The pointer is non-null and aligned, and `len` bytes are initialized.
        unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.len) }
    }

    /// Drops all the contents of the buffer, setting its length to `0`.
    ///
    /// The capacity is left unchanged.
    #[inline]
    pub fn clear(&mut self) {
        self.len = 0;
    }

    /// Reserves capacity for at least `additional` more bytes to be inserted
    /// in the given `IoBufferMut`. The collection may reserve more space to
    /// speculatively avoid frequent reallocations. After calling `reserve`,
    /// capacity will be greater than or equal to `self.len() + additional`.
    /// Does nothing if capacity is already sufficient.
    ///
    /// # Panics
    ///
    /// Panics if the new capacity exceeds `isize::MAX` _bytes_.
    pub fn reserve(&mut self, additional: usize) {
        if additional > self.capacity() - self.len() {
            self.reserve_inner(additional);
        }
    }

    /// Grows the storage to at least `len + additional` bytes, at least doubling the capacity.
    fn reserve_inner(&mut self, additional: usize) {
        let Some(required_cap) = self.len().checked_add(additional) else {
            capacity_overflow()
        };

        let old_capacity = self.capacity();

        // This guarantees exponential growth. The doubling cannot overflow
        // because `cap <= isize::MAX` and the type of `cap` is `usize`.
        let cap = cmp::max(old_capacity * 2, required_cap);

        if cap == 0 {
            // Nothing to reserve; keep the dangling pointer rather than allocating zero bytes.
            return;
        }

        if !is_valid_alloc(cap) {
            capacity_overflow()
        }
        let new_layout = Layout::from_size_align(cap, ALIGN).expect("Invalid layout");

        let ptr = if old_capacity == 0 {
            // There is no previous allocation to resize: the current pointer is dangling.
            // SAFETY: `new_layout` has a non-zero size.
            unsafe { alloc::alloc(new_layout) }
        } else {
            // SAFETY: the old allocation was made by the global allocator with exactly
            // `old_layout`, and `new_layout` validated the new size for this alignment.
            unsafe {
                let old_layout = Layout::from_size_align_unchecked(old_capacity, ALIGN);
                alloc::realloc(self.as_mut_ptr(), old_layout, new_layout.size())
            }
        };
        if ptr.is_null() {
            alloc::handle_alloc_error(new_layout);
        }

        self.ptr = IoBufferPtr(ptr);
        self.capacity = cap;
    }

    /// Consumes and leaks the `IoBufferMut`, returning a mutable reference to the contents, &'a mut [u8].
    ///
    /// The returned slice covers the `len()` initialized bytes; the allocation is never freed.
    pub fn leak<'a>(self) -> &'a mut [u8] {
        // `ManuallyDrop` prevents `Drop` from freeing the storage the slice points into.
        let mut buf = ManuallyDrop::new(self);
        // SAFETY: leaking the buffer as intended.
        unsafe { slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.len) }
    }
}

/// Panics to report a capacity computation that cannot be satisfied.
fn capacity_overflow() -> ! {
    panic!("capacity overflow")
}

// We need to guarantee the following:
// * We don't ever allocate `> isize::MAX` byte-size objects.
// * We don't overflow `usize::MAX` and actually allocate too little.
//
// On 64-bit we just need to check for overflow since trying to allocate
// `> isize::MAX` bytes will surely fail. On 32-bit and 16-bit we need to add
// an extra guard for this in case we're running on a platform which can use
// all 4GB in user-space, e.g., PAE or x32.
#[inline]
fn is_valid_alloc(alloc_size: usize) -> bool {
    !(usize::BITS < 64 && alloc_size > isize::MAX as usize)
}

impl<const ALIGN: usize> Drop for AlignedBufferMut<ALIGN> {
    fn drop(&mut self) {
        if self.capacity == 0 {
            // Zero-capacity buffers never allocated; the pointer is dangling.
            return;
        }
        // SAFETY: memory was allocated by the global allocator with the same layout.
        unsafe {
            alloc::dealloc(
                self.as_mut_ptr(),
                Layout::from_size_align_unchecked(self.capacity, ALIGN),
            )
        }
    }
}

impl<const ALIGN: usize> Deref for AlignedBufferMut<ALIGN> {
    type Target = [u8];

    /// Dereferences to the initialized part of the buffer.
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<const ALIGN: usize> DerefMut for AlignedBufferMut<ALIGN> {
    /// Mutably dereferences to the initialized part of the buffer.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}
/// SAFETY: When advancing the internal cursor, the caller needs to make sure the bytes advanced past have been initialized.
unsafe impl<const ALIGN: usize> bytes::BufMut for AlignedBufferMut<ALIGN> {
    /// Number of bytes that can still be written: the unused tail of the current capacity.
    #[inline]
    fn remaining_mut(&self) -> usize {
        // Writes through `BufMut` are not meant to grow the buffer, so the writable space
        // is bounded by `self.capacity` rather than by `isize::MAX` as for a `Vec`.
        self.capacity() - self.len()
    }

    /// Moves the length forward by `cnt` bytes; panics if fewer than `cnt` bytes remain.
    // SAFETY: Caller needs to make sure the bytes being advanced past have been initialized.
    #[inline]
    unsafe fn advance_mut(&mut self, cnt: usize) {
        let available = self.remaining_mut();
        if cnt > available {
            panic_advance(cnt, available);
        }

        // Addition will not overflow since the sum is at most the capacity.
        let new_len = self.len() + cnt;
        self.set_len(new_len);
    }

    /// Returns the writable tail of the buffer, from `len` up to `capacity`.
    #[inline]
    fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
        let filled = self.len();
        // The subtraction will not underflow since `len <= cap`.
        let spare = self.capacity() - filled;

        // SAFETY: Since `self.ptr` is valid for `capacity` bytes, `self.ptr.add(filled)` must be
        // valid for `capacity - filled` bytes.
        unsafe { UninitSlice::from_raw_parts_mut(self.as_mut_ptr().add(filled), spare) }
    }
}
/// Panic with a nice error message.
///
/// `idx` is the number of bytes the caller tried to advance by and `len` is the value
/// reported as "the len" in the message (`advance_mut` passes the remaining writable
/// space). Marked `#[cold]` because this is the failure path.
#[cold]
fn panic_advance(idx: usize, len: usize) -> ! {
panic!(
"advance out of bounds: the len is {} but advancing by {}",
len, idx
);
}
/// Safety: [`IoBufferMut`] has exclusive ownership of the io buffer,
/// and the location remains stable even if [`Self`] is moved.
unsafe impl<const ALIGN: usize> tokio_epoll_uring::IoBuf for AlignedBufferMut<ALIGN> {
// Pointer to the start of the buffer's storage.
fn stable_ptr(&self) -> *const u8 {
self.as_ptr()
}
// Number of initialized bytes, i.e. the buffer's current length.
fn bytes_init(&self) -> usize {
self.len()
}
// Total size of the storage, i.e. the buffer's capacity.
fn bytes_total(&self) -> usize {
self.capacity()
}
}
// SAFETY: See above.
unsafe impl<const ALIGN: usize> tokio_epoll_uring::IoBufMut for AlignedBufferMut<ALIGN> {
// Mutable pointer to the start of the buffer's storage.
fn stable_mut_ptr(&mut self) -> *mut u8 {
self.as_mut_ptr()
}
// Records that the first `init_len` bytes are initialized. The length is only ever
// raised here: a smaller `init_len` leaves the current length untouched.
unsafe fn set_init(&mut self, init_len: usize) {
if self.len() < init_len {
self.set_len(init_len);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
const ALIGN: usize = 4 * 1024;
type TestIoBufferMut = AlignedBufferMut<ALIGN>;
#[test]
fn test_with_capacity() {
let v = TestIoBufferMut::with_capacity(ALIGN * 4);
assert_eq!(v.len(), 0);
assert_eq!(v.capacity(), ALIGN * 4);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
let v = TestIoBufferMut::with_capacity(ALIGN / 2);
assert_eq!(v.len(), 0);
assert_eq!(v.capacity(), ALIGN / 2);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
}
#[test]
fn test_with_capacity_zeroed() {
let v = TestIoBufferMut::with_capacity_zeroed(ALIGN);
assert_eq!(v.len(), ALIGN);
assert_eq!(v.capacity(), ALIGN);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
assert_eq!(&v[..], &[0; ALIGN])
}
#[test]
fn test_reserve() {
use bytes::BufMut;
let mut v = TestIoBufferMut::with_capacity(ALIGN);
let capacity = v.capacity();
v.reserve(capacity);
assert_eq!(v.capacity(), capacity);
let data = [b'a'; ALIGN];
v.put(&data[..]);
v.reserve(capacity);
assert!(v.capacity() >= capacity * 2);
assert_eq!(&v[..], &data[..]);
let capacity = v.capacity();
v.clear();
v.reserve(capacity);
assert_eq!(capacity, v.capacity());
}
#[test]
fn test_bytes_put() {
use bytes::BufMut;
let mut v = TestIoBufferMut::with_capacity(ALIGN * 4);
let x = [b'a'; ALIGN];
for _ in 0..2 {
for _ in 0..4 {
v.put(&x[..]);
}
assert_eq!(v.len(), ALIGN * 4);
assert_eq!(v.capacity(), ALIGN * 4);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
v.clear()
}
assert_eq!(v.len(), 0);
assert_eq!(v.capacity(), ALIGN * 4);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
}
#[test]
#[should_panic]
fn test_bytes_put_panic() {
use bytes::BufMut;
const ALIGN: usize = 4 * 1024;
let mut v = TestIoBufferMut::with_capacity(ALIGN * 4);
let x = [b'a'; ALIGN];
for _ in 0..5 {
v.put_slice(&x[..]);
}
}
#[test]
fn test_io_buf_put_slice() {
use tokio_epoll_uring::BoundedBufMut;
const ALIGN: usize = 4 * 1024;
let mut v = TestIoBufferMut::with_capacity(ALIGN);
let x = [b'a'; ALIGN];
for _ in 0..2 {
v.put_slice(&x[..]);
assert_eq!(v.len(), ALIGN);
assert_eq!(v.capacity(), ALIGN);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
v.clear()
}
assert_eq!(v.len(), 0);
assert_eq!(v.capacity(), ALIGN);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
}
}

View File

@@ -0,0 +1,9 @@
#![allow(unused)]
use tokio_epoll_uring::IoBufMut;
use crate::virtual_file::IoBufferMut;
/// Marker trait (it declares no methods) layered on top of [`IoBufMut`].
///
/// NOTE(review): presumably this tags buffers whose memory is aligned suitably for
/// direct IO — confirm against the intended users of this trait.
pub(crate) trait IoBufAlignedMut: IoBufMut {}
// `IoBufferMut` is the only implementor provided here.
impl IoBufAlignedMut for IoBufferMut {}

View File

@@ -1,5 +1,6 @@
//! See [`FullSlice`].
use crate::virtual_file::IoBufferMut;
use bytes::{Bytes, BytesMut};
use std::ops::{Deref, Range};
use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
@@ -76,3 +77,4 @@ macro_rules! impl_io_buf_ext {
impl_io_buf_ext!(Bytes);
impl_io_buf_ext!(BytesMut);
impl_io_buf_ext!(Vec<u8>);
impl_io_buf_ext!(IoBufferMut);

View File

@@ -146,8 +146,6 @@ ConstructDeltaMessage()
if (RootTable.role_table)
{
JsonbValue roles;
HASH_SEQ_STATUS status;
RoleEntry *entry;
roles.type = jbvString;
roles.val.string.val = "roles";
@@ -155,6 +153,9 @@ ConstructDeltaMessage()
pushJsonbValue(&state, WJB_KEY, &roles);
pushJsonbValue(&state, WJB_BEGIN_ARRAY, NULL);
HASH_SEQ_STATUS status;
RoleEntry *entry;
hash_seq_init(&status, RootTable.role_table);
while ((entry = hash_seq_search(&status)) != NULL)
{
@@ -189,12 +190,10 @@ ConstructDeltaMessage()
}
pushJsonbValue(&state, WJB_END_ARRAY, NULL);
}
{
JsonbValue *result = pushJsonbValue(&state, WJB_END_OBJECT, NULL);
Jsonb *jsonb = JsonbValueToJsonb(result);
JsonbValue *result = pushJsonbValue(&state, WJB_END_OBJECT, NULL);
Jsonb *jsonb = JsonbValueToJsonb(result);
return JsonbToCString(NULL, &jsonb->root, 0 /* estimated_len */ );
}
return JsonbToCString(NULL, &jsonb->root, 0 /* estimated_len */ );
}
#define ERROR_SIZE 1024
@@ -273,28 +272,32 @@ SendDeltasToControlPlane()
curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, ErrorWriteCallback);
}
char *message = ConstructDeltaMessage();
ErrorString str;
str.size = 0;
curl_easy_setopt(handle, CURLOPT_POSTFIELDS, message);
curl_easy_setopt(handle, CURLOPT_WRITEDATA, &str);
const int num_retries = 5;
CURLcode curl_status;
for (int i = 0; i < num_retries; i++)
{
if ((curl_status = curl_easy_perform(handle)) == 0)
break;
elog(LOG, "Curl request failed on attempt %d: %s", i, CurlErrorBuf);
pg_usleep(1000 * 1000);
}
if (curl_status != CURLE_OK)
{
elog(ERROR, "Failed to perform curl request: %s", CurlErrorBuf);
}
else
{
char *message = ConstructDeltaMessage();
ErrorString str;
const int num_retries = 5;
CURLcode curl_status;
long response_code;
str.size = 0;
curl_easy_setopt(handle, CURLOPT_POSTFIELDS, message);
curl_easy_setopt(handle, CURLOPT_WRITEDATA, &str);
for (int i = 0; i < num_retries; i++)
{
if ((curl_status = curl_easy_perform(handle)) == 0)
break;
elog(LOG, "Curl request failed on attempt %d: %s", i, CurlErrorBuf);
pg_usleep(1000 * 1000);
}
if (curl_status != CURLE_OK)
elog(ERROR, "Failed to perform curl request: %s", CurlErrorBuf);
if (curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &response_code) != CURLE_UNKNOWN_OPTION)
{
if (response_code != 200)
@@ -373,11 +376,10 @@ MergeTable()
if (old_table->db_table)
{
InitDbTableIfNeeded();
DbEntry *entry;
HASH_SEQ_STATUS status;
InitDbTableIfNeeded();
hash_seq_init(&status, old_table->db_table);
while ((entry = hash_seq_search(&status)) != NULL)
{
@@ -419,11 +421,10 @@ MergeTable()
if (old_table->role_table)
{
InitRoleTableIfNeeded();
RoleEntry *entry;
HASH_SEQ_STATUS status;
InitRoleTableIfNeeded();
hash_seq_init(&status, old_table->role_table);
while ((entry = hash_seq_search(&status)) != NULL)
{
@@ -514,12 +515,9 @@ RoleIsNeonSuperuser(const char *role_name)
static void
HandleCreateDb(CreatedbStmt *stmt)
{
InitDbTableIfNeeded();
DefElem *downer = NULL;
ListCell *option;
bool found = false;
DbEntry *entry;
InitDbTableIfNeeded();
foreach(option, stmt->options)
{
@@ -528,11 +526,13 @@ HandleCreateDb(CreatedbStmt *stmt)
if (strcmp(defel->defname, "owner") == 0)
downer = defel;
}
bool found = false;
DbEntry *entry = hash_search(
CurrentDdlTable->db_table,
stmt->dbname,
HASH_ENTER,
&found);
entry = hash_search(CurrentDdlTable->db_table,
stmt->dbname,
HASH_ENTER,
&found);
if (!found)
memset(entry->old_name, 0, sizeof(entry->old_name));
@@ -554,24 +554,21 @@ HandleCreateDb(CreatedbStmt *stmt)
static void
HandleAlterOwner(AlterOwnerStmt *stmt)
{
const char *name;
bool found = false;
DbEntry *entry;
const char *new_owner;
if (stmt->objectType != OBJECT_DATABASE)
return;
InitDbTableIfNeeded();
const char *name = strVal(stmt->object);
bool found = false;
DbEntry *entry = hash_search(
CurrentDdlTable->db_table,
name,
HASH_ENTER,
&found);
name = strVal(stmt->object);
entry = hash_search(CurrentDdlTable->db_table,
name,
HASH_ENTER,
&found);
if (!found)
memset(entry->old_name, 0, sizeof(entry->old_name));
const char *new_owner = get_rolespec_name(stmt->newowner);
new_owner = get_rolespec_name(stmt->newowner);
if (RoleIsNeonSuperuser(new_owner))
elog(ERROR, "can't alter owner to neon_superuser");
entry->owner = get_role_oid(new_owner, false);
@@ -581,23 +578,21 @@ HandleAlterOwner(AlterOwnerStmt *stmt)
static void
HandleDbRename(RenameStmt *stmt)
{
bool found = false;
DbEntry *entry;
DbEntry *entry_for_new_name;
Assert(stmt->renameType == OBJECT_DATABASE);
InitDbTableIfNeeded();
entry = hash_search(CurrentDdlTable->db_table,
stmt->subname,
HASH_FIND,
&found);
bool found = false;
DbEntry *entry = hash_search(
CurrentDdlTable->db_table,
stmt->subname,
HASH_FIND,
&found);
DbEntry *entry_for_new_name = hash_search(
CurrentDdlTable->db_table,
stmt->newname,
HASH_ENTER,
NULL);
entry_for_new_name = hash_search(CurrentDdlTable->db_table,
stmt->newname,
HASH_ENTER,
NULL);
entry_for_new_name->type = Op_Set;
if (found)
{
if (entry->old_name[0] != '\0')
@@ -605,7 +600,8 @@ HandleDbRename(RenameStmt *stmt)
else
strlcpy(entry_for_new_name->old_name, entry->name, NAMEDATALEN);
entry_for_new_name->owner = entry->owner;
hash_search(CurrentDdlTable->db_table,
hash_search(
CurrentDdlTable->db_table,
stmt->subname,
HASH_REMOVE,
NULL);
@@ -620,15 +616,14 @@ HandleDbRename(RenameStmt *stmt)
static void
HandleDropDb(DropdbStmt *stmt)
{
bool found = false;
DbEntry *entry;
InitDbTableIfNeeded();
bool found = false;
DbEntry *entry = hash_search(
CurrentDdlTable->db_table,
stmt->dbname,
HASH_ENTER,
&found);
entry = hash_search(CurrentDdlTable->db_table,
stmt->dbname,
HASH_ENTER,
&found);
entry->type = Op_Delete;
entry->owner = InvalidOid;
if (!found)
@@ -638,14 +633,16 @@ HandleDropDb(DropdbStmt *stmt)
static void
HandleCreateRole(CreateRoleStmt *stmt)
{
InitRoleTableIfNeeded();
bool found = false;
RoleEntry *entry;
DefElem *dpass;
RoleEntry *entry = hash_search(
CurrentDdlTable->role_table,
stmt->role,
HASH_ENTER,
&found);
DefElem *dpass = NULL;
ListCell *option;
InitRoleTableIfNeeded();
dpass = NULL;
foreach(option, stmt->options)
{
DefElem *defel = lfirst(option);
@@ -653,11 +650,6 @@ HandleCreateRole(CreateRoleStmt *stmt)
if (strcmp(defel->defname, "password") == 0)
dpass = defel;
}
entry = hash_search(CurrentDdlTable->role_table,
stmt->role,
HASH_ENTER,
&found);
if (!found)
memset(entry->old_name, 0, sizeof(entry->old_name));
if (dpass && dpass->arg)
@@ -670,18 +662,14 @@ HandleCreateRole(CreateRoleStmt *stmt)
static void
HandleAlterRole(AlterRoleStmt *stmt)
{
const char *role_name = stmt->role->rolename;
DefElem *dpass;
ListCell *option;
bool found = false;
RoleEntry *entry;
InitRoleTableIfNeeded();
DefElem *dpass = NULL;
ListCell *option;
const char *role_name = stmt->role->rolename;
if (RoleIsNeonSuperuser(role_name) && !superuser())
elog(ERROR, "can't ALTER neon_superuser");
dpass = NULL;
foreach(option, stmt->options)
{
DefElem *defel = lfirst(option);
@@ -692,11 +680,13 @@ HandleAlterRole(AlterRoleStmt *stmt)
/* We only care about updates to the password */
if (!dpass)
return;
bool found = false;
RoleEntry *entry = hash_search(
CurrentDdlTable->role_table,
role_name,
HASH_ENTER,
&found);
entry = hash_search(CurrentDdlTable->role_table,
role_name,
HASH_ENTER,
&found);
if (!found)
memset(entry->old_name, 0, sizeof(entry->old_name));
if (dpass->arg)
@@ -709,22 +699,20 @@ HandleAlterRole(AlterRoleStmt *stmt)
static void
HandleRoleRename(RenameStmt *stmt)
{
bool found = false;
RoleEntry *entry;
RoleEntry *entry_for_new_name;
Assert(stmt->renameType == OBJECT_ROLE);
InitRoleTableIfNeeded();
Assert(stmt->renameType == OBJECT_ROLE);
bool found = false;
RoleEntry *entry = hash_search(
CurrentDdlTable->role_table,
stmt->subname,
HASH_FIND,
&found);
entry = hash_search(CurrentDdlTable->role_table,
stmt->subname,
HASH_FIND,
&found);
entry_for_new_name = hash_search(CurrentDdlTable->role_table,
stmt->newname,
HASH_ENTER,
NULL);
RoleEntry *entry_for_new_name = hash_search(
CurrentDdlTable->role_table,
stmt->newname,
HASH_ENTER,
NULL);
entry_for_new_name->type = Op_Set;
if (found)
@@ -750,9 +738,8 @@ HandleRoleRename(RenameStmt *stmt)
static void
HandleDropRole(DropRoleStmt *stmt)
{
ListCell *item;
InitRoleTableIfNeeded();
ListCell *item;
foreach(item, stmt->roles)
{

View File

@@ -170,14 +170,12 @@ lfc_disable(char const *op)
if (lfc_desc > 0)
{
int rc;
/*
* If the error was caused by ENOSPC, truncating the file may help to
* reclaim some space.
*/
pgstat_report_wait_start(WAIT_EVENT_NEON_LFC_TRUNCATE);
rc = ftruncate(lfc_desc, 0);
int rc = ftruncate(lfc_desc, 0);
pgstat_report_wait_end();
if (rc < 0)
@@ -618,7 +616,7 @@ lfc_evict(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
*/
if (entry->bitmap[chunk_offs >> 5] == 0)
{
bool has_remaining_pages = false;
bool has_remaining_pages;
for (int i = 0; i < CHUNK_BITMAP_SIZE; i++)
{
@@ -668,6 +666,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
BufferTag tag;
FileCacheEntry *entry;
ssize_t rc;
bool result = true;
uint32 hash;
uint64 generation;
uint32 entry_offset;
@@ -926,10 +925,10 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
/* We can reuse a hole that was left behind when the LFC was shrunk previously */
FileCacheEntry *hole = dlist_container(FileCacheEntry, list_node, dlist_pop_head_node(&lfc_ctl->holes));
uint32 offset = hole->offset;
bool hole_found;
bool found;
hash_search_with_hash_value(lfc_hash, &hole->key, hole->hash, HASH_REMOVE, &hole_found);
CriticalAssert(hole_found);
hash_search_with_hash_value(lfc_hash, &hole->key, hole->hash, HASH_REMOVE, &found);
CriticalAssert(found);
lfc_ctl->used += 1;
entry->offset = offset; /* reuse the hole */
@@ -1005,7 +1004,7 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS)
Datum result;
HeapTuple tuple;
char const *key;
uint64 value = 0;
uint64 value;
Datum values[NUM_NEON_GET_STATS_COLS];
bool nulls[NUM_NEON_GET_STATS_COLS];

View File

@@ -116,6 +116,8 @@ addSHLL(HyperLogLogState *cState, uint32 hash)
{
uint8 count;
uint32 index;
size_t i;
size_t j;
TimestampTz now = GetCurrentTimestamp();
/* Use the first "k" (registerWidth) bits as a zero based index */

View File

@@ -89,6 +89,7 @@ typedef struct
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
static void walproposer_shmem_request(void);
#endif
static shmem_startup_hook_type prev_shmem_startup_hook;
static PagestoreShmemState *pagestore_shared;
@@ -440,8 +441,8 @@ pageserver_connect(shardno_t shard_no, int elevel)
return false;
}
shard->state = PS_Connecting_Startup;
/* fallthrough */
}
/* FALLTHROUGH */
case PS_Connecting_Startup:
{
char *pagestream_query;
@@ -452,6 +453,8 @@ pageserver_connect(shardno_t shard_no, int elevel)
do
{
WaitEvent event;
switch (poll_result)
{
default: /* unknown/unused states are handled as a failed connection */
@@ -582,8 +585,8 @@ pageserver_connect(shardno_t shard_no, int elevel)
}
shard->state = PS_Connecting_PageStream;
/* fallthrough */
}
/* FALLTHROUGH */
case PS_Connecting_PageStream:
{
neon_shard_log(shard_no, DEBUG5, "Connection state: Connecting_PageStream");
@@ -628,8 +631,8 @@ pageserver_connect(shardno_t shard_no, int elevel)
}
shard->state = PS_Connected;
/* fallthrough */
}
/* FALLTHROUGH */
case PS_Connected:
/*
* We successfully connected. Future connections to this PageServer

View File

@@ -94,6 +94,7 @@ neon_perf_counters_to_metrics(neon_per_backend_counters *counters)
metric_t *metrics = palloc((NUM_METRICS + 1) * sizeof(metric_t));
uint64 bucket_accum;
int i = 0;
Datum getpage_wait_str;
metrics[i].name = "getpage_wait_seconds_count";
metrics[i].is_bucket = false;
@@ -223,6 +224,7 @@ neon_get_perf_counters(PG_FUNCTION_ARGS)
ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
Datum values[3];
bool nulls[3];
Datum getpage_wait_str;
neon_per_backend_counters totals = {0};
metric_t *metrics;

View File

@@ -213,6 +213,32 @@ extern const f_smgr *smgr_neon(ProcNumber backend, NRelFileInfo rinfo);
extern void smgr_init_neon(void);
extern void readahead_buffer_resize(int newsize, void *extra);
/* Neon storage manager functionality */
extern void neon_init(void);
extern void neon_open(SMgrRelation reln);
extern void neon_close(SMgrRelation reln, ForkNumber forknum);
extern void neon_create(SMgrRelation reln, ForkNumber forknum, bool isRedo);
extern bool neon_exists(SMgrRelation reln, ForkNumber forknum);
extern void neon_unlink(NRelFileInfoBackend rnode, ForkNumber forknum, bool isRedo);
#if PG_MAJORVERSION_NUM < 16
extern void neon_extend(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum, char *buffer, bool skipFsync);
#else
extern void neon_extend(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum, const void *buffer, bool skipFsync);
extern void neon_zeroextend(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum, int nbuffers, bool skipFsync);
#endif
#if PG_MAJORVERSION_NUM >=17
extern bool neon_prefetch(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum, int nblocks);
#else
extern bool neon_prefetch(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum);
#endif
/*
* LSN values associated with each request to the pageserver
*/
@@ -252,7 +278,13 @@ extern PGDLLEXPORT void neon_read_at_lsn(NRelFileInfo rnode, ForkNumber forkNum,
extern PGDLLEXPORT void neon_read_at_lsn(NRelFileInfo rnode, ForkNumber forkNum, BlockNumber blkno,
neon_request_lsns request_lsns, void *buffer);
#endif
extern void neon_writeback(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum, BlockNumber nblocks);
extern BlockNumber neon_nblocks(SMgrRelation reln, ForkNumber forknum);
extern int64 neon_dbsize(Oid dbNode);
extern void neon_truncate(SMgrRelation reln, ForkNumber forknum,
BlockNumber nblocks);
extern void neon_immedsync(SMgrRelation reln, ForkNumber forknum);
/* utils for neon relsize cache */
extern void relsize_hash_init(void);

View File

@@ -118,8 +118,6 @@ static UnloggedBuildPhase unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;
static bool neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id);
static bool (*old_redo_read_buffer_filter) (XLogReaderState *record, uint8 block_id) = NULL;
static BlockNumber neon_nblocks(SMgrRelation reln, ForkNumber forknum);
/*
* Prefetch implementation:
*
@@ -738,7 +736,7 @@ static void
prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns)
{
bool found;
uint64 mySlotNo PG_USED_FOR_ASSERTS_ONLY = slot->my_ring_index;
uint64 mySlotNo = slot->my_ring_index;
NeonGetPageRequest request = {
.req.tag = T_NeonGetPageRequest,
@@ -1465,6 +1463,7 @@ log_newpages_copy(NRelFileInfo * rinfo, ForkNumber forkNum, BlockNumber blkno,
BlockNumber blknos[XLR_MAX_BLOCK_ID];
Page pageptrs[XLR_MAX_BLOCK_ID];
int nregistered = 0;
XLogRecPtr result = 0;
for (int i = 0; i < nblocks; i++)
{
@@ -1777,7 +1776,7 @@ neon_wallog_page(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, co
/*
* neon_init() -- Initialize private state
*/
static void
void
neon_init(void)
{
Size prfs_size;
@@ -2167,7 +2166,7 @@ neon_prefetch_response_usable(neon_request_lsns *request_lsns,
/*
* neon_exists() -- Does the physical file exist?
*/
static bool
bool
neon_exists(SMgrRelation reln, ForkNumber forkNum)
{
bool exists;
@@ -2273,7 +2272,7 @@ neon_exists(SMgrRelation reln, ForkNumber forkNum)
*
* If isRedo is true, it's okay for the relation to exist already.
*/
static void
void
neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo)
{
switch (reln->smgr_relpersistence)
@@ -2349,7 +2348,7 @@ neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo)
* Note: any failure should be reported as WARNING not ERROR, because
* we are usually not in a transaction anymore when this is called.
*/
static void
void
neon_unlink(NRelFileInfoBackend rinfo, ForkNumber forkNum, bool isRedo)
{
/*
@@ -2373,7 +2372,7 @@ neon_unlink(NRelFileInfoBackend rinfo, ForkNumber forkNum, bool isRedo)
* EOF). Note that we assume writing a block beyond current EOF
* causes intervening file space to become filled with zeroes.
*/
static void
void
#if PG_MAJORVERSION_NUM < 16
neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno,
char *buffer, bool skipFsync)
@@ -2465,7 +2464,7 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno,
}
#if PG_MAJORVERSION_NUM >= 16
static void
void
neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blocknum,
int nblocks, bool skipFsync)
{
@@ -2561,7 +2560,7 @@ neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blocknum,
/*
* neon_open() -- Initialize newly-opened relation.
*/
static void
void
neon_open(SMgrRelation reln)
{
/*
@@ -2579,7 +2578,7 @@ neon_open(SMgrRelation reln)
/*
* neon_close() -- Close the specified relation, if it isn't closed already.
*/
static void
void
neon_close(SMgrRelation reln, ForkNumber forknum)
{
/*
@@ -2594,12 +2593,13 @@ neon_close(SMgrRelation reln, ForkNumber forknum)
/*
* neon_prefetch() -- Initiate asynchronous read of the specified block of a relation
*/
static bool
bool
neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
int nblocks)
{
uint64 ring_index PG_USED_FOR_ASSERTS_ONLY;
BufferTag tag;
bool io_initiated = false;
switch (reln->smgr_relpersistence)
{
@@ -2623,6 +2623,7 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
while (nblocks > 0)
{
int iterblocks = Min(nblocks, PG_IOV_MAX);
int seqlen = 0;
bits8 lfc_present[PG_IOV_MAX / 8];
memset(lfc_present, 0, sizeof(lfc_present));
@@ -2634,6 +2635,8 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
continue;
}
io_initiated = true;
tag.blockNum = blocknum;
for (int i = 0; i < PG_IOV_MAX / 8; i++)
@@ -2656,7 +2659,7 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
/*
* neon_prefetch() -- Initiate asynchronous read of the specified block of a relation
*/
static bool
bool
neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum)
{
uint64 ring_index PG_USED_FOR_ASSERTS_ONLY;
@@ -2700,7 +2703,7 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum)
* This accepts a range of blocks because flushing several pages at once is
* considerably more efficient than doing so individually.
*/
static void
void
neon_writeback(SMgrRelation reln, ForkNumber forknum,
BlockNumber blocknum, BlockNumber nblocks)
{
@@ -2921,10 +2924,10 @@ neon_read_at_lsn(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
* neon_read() -- Read the specified block from a relation.
*/
#if PG_MAJORVERSION_NUM < 16
static void
void
neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, char *buffer)
#else
static void
void
neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer)
#endif
{
@@ -3033,7 +3036,7 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
#endif /* PG_MAJORVERSION_NUM <= 16 */
#if PG_MAJORVERSION_NUM >= 17
static void
void
neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
void **buffers, BlockNumber nblocks)
{
@@ -3197,7 +3200,6 @@ hexdump_page(char *page)
}
#endif
#if PG_MAJORVERSION_NUM < 17
/*
* neon_write() -- Write the supplied block at the appropriate location.
*
@@ -3205,7 +3207,7 @@ hexdump_page(char *page)
* relation (ie, those before the current EOF). To extend a relation,
* use mdextend().
*/
static void
void
#if PG_MAJORVERSION_NUM < 16
neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, char *buffer, bool skipFsync)
#else
@@ -3271,12 +3273,11 @@ neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, const vo
#endif
#endif
}
#endif
#if PG_MAJORVERSION_NUM >= 17
static void
void
neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
const void **buffers, BlockNumber nblocks, bool skipFsync)
{
@@ -3326,7 +3327,7 @@ neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
/*
* neon_nblocks() -- Get the number of blocks stored in a relation.
*/
static BlockNumber
BlockNumber
neon_nblocks(SMgrRelation reln, ForkNumber forknum)
{
NeonResponse *resp;
@@ -3463,7 +3464,7 @@ neon_dbsize(Oid dbNode)
/*
* neon_truncate() -- Truncate relation to specified number of blocks.
*/
static void
void
neon_truncate(SMgrRelation reln, ForkNumber forknum, BlockNumber nblocks)
{
XLogRecPtr lsn;
@@ -3532,7 +3533,7 @@ neon_truncate(SMgrRelation reln, ForkNumber forknum, BlockNumber nblocks)
* crash before the next checkpoint syncs the newly-inactive segment, that
* segment may survive recovery, reintroducing unwanted data into the table.
*/
static void
void
neon_immedsync(SMgrRelation reln, ForkNumber forknum)
{
switch (reln->smgr_relpersistence)
@@ -3562,8 +3563,8 @@ neon_immedsync(SMgrRelation reln, ForkNumber forknum)
}
#if PG_MAJORVERSION_NUM >= 17
static void
neon_registersync(SMgrRelation reln, ForkNumber forknum)
void
neon_regisersync(SMgrRelation reln, ForkNumber forknum)
{
switch (reln->smgr_relpersistence)
{
@@ -3747,8 +3748,6 @@ neon_read_slru_segment(SMgrRelation reln, const char* path, int segno, void* buf
SlruKind kind;
int n_blocks;
shardno_t shard_no = 0; /* All SLRUs are at shard 0 */
NeonResponse *resp;
NeonGetSlruSegmentRequest request;
/*
* Compute a request LSN to use, similar to neon_get_request_lsns() but the
@@ -3787,7 +3786,8 @@ neon_read_slru_segment(SMgrRelation reln, const char* path, int segno, void* buf
else
return -1;
request = (NeonGetSlruSegmentRequest) {
NeonResponse *resp;
NeonGetSlruSegmentRequest request = {
.req.tag = T_NeonGetSlruSegmentRequest,
.req.lsn = request_lsn,
.req.not_modified_since = not_modified_since,
@@ -3894,7 +3894,7 @@ static const struct f_smgr neon_smgr =
.smgr_truncate = neon_truncate,
.smgr_immedsync = neon_immedsync,
#if PG_MAJORVERSION_NUM >= 17
.smgr_registersync = neon_registersync,
.smgr_registersync = neon_regisersync,
#endif
.smgr_start_unlogged_build = neon_start_unlogged_build,
.smgr_finish_unlogged_build_phase_1 = neon_finish_unlogged_build_phase_1,

View File

@@ -252,6 +252,8 @@ WalProposerPoll(WalProposer *wp)
/* timeout expired: poll state */
if (rc == 0 || TimeToReconnect(wp, now) <= 0)
{
TimestampTz now;
/*
* If no WAL was generated during timeout (and we have already
* collected the quorum), then send empty keepalive message
@@ -267,7 +269,8 @@ WalProposerPoll(WalProposer *wp)
now = wp->api.get_current_timestamp(wp);
for (int i = 0; i < wp->n_safekeepers; i++)
{
sk = &wp->safekeeper[i];
Safekeeper *sk = &wp->safekeeper[i];
if (TimestampDifferenceExceeds(sk->latestMsgReceivedAt, now,
wp->config->safekeeper_connection_timeout))
{
@@ -1077,7 +1080,7 @@ SendProposerElected(Safekeeper *sk)
ProposerElected msg;
TermHistory *th;
term_t lastCommonTerm;
int idx;
int i;
/* Now that we are ready to send it's a good moment to create WAL reader */
wp->api.wal_reader_allocate(sk);
@@ -1096,15 +1099,15 @@ SendProposerElected(Safekeeper *sk)
/* We must start somewhere. */
Assert(wp->propTermHistory.n_entries >= 1);
for (idx = 0; idx < Min(wp->propTermHistory.n_entries, th->n_entries); idx++)
for (i = 0; i < Min(wp->propTermHistory.n_entries, th->n_entries); i++)
{
if (wp->propTermHistory.entries[idx].term != th->entries[idx].term)
if (wp->propTermHistory.entries[i].term != th->entries[i].term)
break;
/* term must begin everywhere at the same point */
Assert(wp->propTermHistory.entries[idx].lsn == th->entries[idx].lsn);
Assert(wp->propTermHistory.entries[i].lsn == th->entries[i].lsn);
}
idx--; /* step back to the last common term */
if (idx < 0)
i--; /* step back to the last common term */
if (i < 0)
{
/* safekeeper is empty or no common point, start from the beginning */
sk->startStreamingAt = wp->propTermHistory.entries[0].lsn;
@@ -1125,14 +1128,14 @@ SendProposerElected(Safekeeper *sk)
* proposer, LSN it is currently writing, but then we just pick
* safekeeper pos as it obviously can't be higher.
*/
if (wp->propTermHistory.entries[idx].term == wp->propTerm)
if (wp->propTermHistory.entries[i].term == wp->propTerm)
{
sk->startStreamingAt = sk->voteResponse.flushLsn;
}
else
{
XLogRecPtr propEndLsn = wp->propTermHistory.entries[idx + 1].lsn;
XLogRecPtr skEndLsn = (idx + 1 < th->n_entries ? th->entries[idx + 1].lsn : sk->voteResponse.flushLsn);
XLogRecPtr propEndLsn = wp->propTermHistory.entries[i + 1].lsn;
XLogRecPtr skEndLsn = (i + 1 < th->n_entries ? th->entries[i + 1].lsn : sk->voteResponse.flushLsn);
sk->startStreamingAt = Min(propEndLsn, skEndLsn);
}
@@ -1146,7 +1149,7 @@ SendProposerElected(Safekeeper *sk)
msg.termHistory = &wp->propTermHistory;
msg.timelineStartLsn = wp->timelineStartLsn;
lastCommonTerm = idx >= 0 ? wp->propTermHistory.entries[idx].term : 0;
lastCommonTerm = i >= 0 ? wp->propTermHistory.entries[i].term : 0;
wp_log(LOG,
"sending elected msg to node " UINT64_FORMAT " term=" UINT64_FORMAT ", startStreamingAt=%X/%X (lastCommonTerm=" UINT64_FORMAT "), termHistory.n_entries=%u to %s:%s, timelineStartLsn=%X/%X",
sk->greetResponse.nodeId, msg.term, LSN_FORMAT_ARGS(msg.startStreamingAt), lastCommonTerm, msg.termHistory->n_entries, sk->host, sk->port, LSN_FORMAT_ARGS(msg.timelineStartLsn));
@@ -1638,7 +1641,7 @@ UpdateDonorShmem(WalProposer *wp)
* Process AppendResponse message from safekeeper.
*/
static void
HandleSafekeeperResponse(WalProposer *wp, Safekeeper *fromsk)
HandleSafekeeperResponse(WalProposer *wp, Safekeeper *sk)
{
XLogRecPtr candidateTruncateLsn;
XLogRecPtr newCommitLsn;
@@ -1657,7 +1660,7 @@ HandleSafekeeperResponse(WalProposer *wp, Safekeeper *fromsk)
* and WAL is committed by the quorum. BroadcastAppendRequest() should be
* called to notify safekeepers about the new commitLsn.
*/
wp->api.process_safekeeper_feedback(wp, fromsk);
wp->api.process_safekeeper_feedback(wp, sk);
/*
* Try to advance truncateLsn -- the last record flushed to all

View File

@@ -725,7 +725,7 @@ extern void WalProposerBroadcast(WalProposer *wp, XLogRecPtr startpos, XLogRecPt
extern void WalProposerPoll(WalProposer *wp);
extern void WalProposerFree(WalProposer *wp);
extern WalproposerShmemState *GetWalpropShmemState(void);
extern WalproposerShmemState *GetWalpropShmemState();
/*
* WaitEventSet API doesn't allow to remove socket, so walproposer_pg uses it to
@@ -745,7 +745,7 @@ extern TimeLineID walprop_pg_get_timeline_id(void);
* catch logging.
*/
#ifdef WALPROPOSER_LIB
extern void WalProposerLibLog(WalProposer *wp, int elevel, char *fmt,...) pg_attribute_printf(3, 4);
extern void WalProposerLibLog(WalProposer *wp, int elevel, char *fmt,...);
#define wp_log(elevel, fmt, ...) WalProposerLibLog(wp, elevel, fmt, ## __VA_ARGS__)
#else
#define wp_log(elevel, fmt, ...) elog(elevel, WP_LOG_PREFIX fmt, ## __VA_ARGS__)

View File

@@ -286,9 +286,6 @@ safekeepers_cmp(char *old, char *new)
static void
assign_neon_safekeepers(const char *newval, void *extra)
{
char *newval_copy;
char *oldval;
if (!am_walproposer)
return;
@@ -298,8 +295,8 @@ assign_neon_safekeepers(const char *newval, void *extra)
}
/* Copy values because we will modify them in split_safekeepers_list() */
newval_copy = pstrdup(newval);
oldval = pstrdup(wal_acceptors_list);
char *newval_copy = pstrdup(newval);
char *oldval = pstrdup(wal_acceptors_list);
/*
* TODO: restarting through FATAL is stupid and introduces 1s delay before
@@ -541,7 +538,7 @@ nwp_shmem_startup_hook(void)
}
WalproposerShmemState *
GetWalpropShmemState(void)
GetWalpropShmemState()
{
Assert(walprop_shared != NULL);
return walprop_shared;

View File

@@ -44,6 +44,27 @@ infobits_desc(StringInfo buf, uint8 infobits, const char *keyname)
appendStringInfoString(buf, "]");
}
/*
 * Append the XLH_TRUNCATE_* bits that are set in "flags" to buf as a
 * bracketed, comma-separated list, e.g. "flags: [CASCADE, RESTART_SEQS]".
 * With no known bit set the output is "flags: []".
 */
static void
truncate_flags_desc(StringInfo buf, uint8 flags)
{
	/* Each label carries its own ", " separator; the last one is trimmed below. */
	static const struct
	{
		uint8		mask;
		const char *text;
	}			flag_names[] = {
		{XLH_TRUNCATE_CASCADE, "CASCADE, "},
		{XLH_TRUNCATE_RESTART_SEQS, "RESTART_SEQS, "}
	};
	int			nflags = (int) (sizeof(flag_names) / sizeof(flag_names[0]));

	appendStringInfoString(buf, "flags: [");

	for (int k = 0; k < nflags; k++)
	{
		if ((flags & flag_names[k].mask) != 0)
			appendStringInfoString(buf, flag_names[k].text);
	}

	/* A trailing space can only come from a label's ", " suffix: drop it. */
	if (buf->data[buf->len - 1] == ' ')
	{
		int			trimmed = buf->len - 2;

		Assert(buf->data[trimmed] == ',');
		buf->len = trimmed;
		buf->data[trimmed] = '\0';
	}

	appendStringInfoString(buf, "]");
}
void
neon_rm_desc(StringInfo buf, XLogReaderState *record)
{

View File

@@ -136,7 +136,7 @@ static bool redo_block_filter(XLogReaderState *record, uint8 block_id);
static void GetPage(StringInfo input_message);
static void Ping(StringInfo input_message);
static ssize_t buffered_read(void *buf, size_t count);
static void CreateFakeSharedMemoryAndSemaphores(void);
static void CreateFakeSharedMemoryAndSemaphores();
static BufferTag target_redo_tag;
@@ -170,40 +170,6 @@ close_range_syscall(unsigned int start_fd, unsigned int count, unsigned int flag
return syscall(__NR_close_range, start_fd, count, flags);
}
/*
 * Seccomp rule table: one PG_SCMP_ALLOW() entry per system call.
 * enter_seccomp_mode() passes this array, together with its length, to
 * seccomp_load_rules().
 */
static PgSeccompRule allowed_syscalls[] =
{
/* Hard requirements */
PG_SCMP_ALLOW(exit_group),
PG_SCMP_ALLOW(pselect6),
PG_SCMP_ALLOW(read),
PG_SCMP_ALLOW(select),
PG_SCMP_ALLOW(write),
/* Memory allocation */
PG_SCMP_ALLOW(brk),
/*
 * mmap/munmap are listed only when MALLOC_NO_MMAP is not defined; when it
 * is defined, enter_seccomp_mode() instead calls mallopt(M_MMAP_MAX, 0) to
 * ask glibc not to use mmap().
 */
#ifndef MALLOC_NO_MMAP
/* TODO: musl doesn't have mallopt */
PG_SCMP_ALLOW(mmap),
PG_SCMP_ALLOW(munmap),
#endif
/*
 * getpid() is called on assertion failure, in ExceptionalCondition.
 * It's not really needed, but it seems pointless to hide it either. The
 * system call is unlikely to expose a kernel vulnerability, and the PID
 * is stored in MyProcPid anyway.
 */
PG_SCMP_ALLOW(getpid),
/*
 * Enable those for a proper shutdown. They are compiled out by the
 * "#if 0" below, so they are not part of the table as built.
 */
#if 0
PG_SCMP_ALLOW(munmap),
PG_SCMP_ALLOW(shmctl),
PG_SCMP_ALLOW(shmdt),
PG_SCMP_ALLOW(unlink), /* shm_unlink */
#endif
};
static void
enter_seccomp_mode(void)
{
@@ -217,12 +183,44 @@ enter_seccomp_mode(void)
(errcode(ERRCODE_SYSTEM_ERROR),
errmsg("seccomp: could not close files >= fd 3")));
PgSeccompRule syscalls[] =
{
/* Hard requirements */
PG_SCMP_ALLOW(exit_group),
PG_SCMP_ALLOW(pselect6),
PG_SCMP_ALLOW(read),
PG_SCMP_ALLOW(select),
PG_SCMP_ALLOW(write),
/* Memory allocation */
PG_SCMP_ALLOW(brk),
#ifndef MALLOC_NO_MMAP
/* TODO: musl doesn't have mallopt */
PG_SCMP_ALLOW(mmap),
PG_SCMP_ALLOW(munmap),
#endif
/*
* getpid() is called on assertion failure, in ExceptionalCondition.
* It's not really needed, but it seems pointless to hide it either. The
* system call is unlikely to expose a kernel vulnerability, and the PID
* is stored in MyProcPid anyway.
*/
PG_SCMP_ALLOW(getpid),
/* Enable those for a proper shutdown.
PG_SCMP_ALLOW(munmap),
PG_SCMP_ALLOW(shmctl),
PG_SCMP_ALLOW(shmdt),
PG_SCMP_ALLOW(unlink), // shm_unlink
*/
};
#ifdef MALLOC_NO_MMAP
/* Ask glibc not to use mmap() */
mallopt(M_MMAP_MAX, 0);
#endif
seccomp_load_rules(allowed_syscalls, lengthof(allowed_syscalls));
seccomp_load_rules(syscalls, lengthof(syscalls));
}
#endif /* HAVE_LIBSECCOMP */
@@ -451,7 +449,7 @@ WalRedoMain(int argc, char *argv[])
* half-initialized postgres.
*/
static void
CreateFakeSharedMemoryAndSemaphores(void)
CreateFakeSharedMemoryAndSemaphores()
{
PGShmemHeader *shim = NULL;
PGShmemHeader *hdr;

29
poetry.lock generated
View File

@@ -2095,7 +2095,6 @@ files = [
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
@@ -2104,8 +2103,6 @@ files = [
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
@@ -2587,7 +2584,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -2733,22 +2729,21 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "responses"
version = "0.25.3"
version = "0.21.0"
description = "A utility library for mocking out the `requests` Python library."
optional = false
python-versions = ">=3.8"
python-versions = ">=3.7"
files = [
{file = "responses-0.25.3-py3-none-any.whl", hash = "sha256:521efcbc82081ab8daa588e08f7e8a64ce79b91c39f6e62199b19159bea7dbcb"},
{file = "responses-0.25.3.tar.gz", hash = "sha256:617b9247abd9ae28313d57a75880422d55ec63c29d33d629697590a034358dba"},
{file = "responses-0.21.0-py3-none-any.whl", hash = "sha256:2dcc863ba63963c0c3d9ee3fa9507cbe36b7d7b0fccb4f0bdfd9e96c539b1487"},
{file = "responses-0.21.0.tar.gz", hash = "sha256:b82502eb5f09a0289d8e209e7bad71ef3978334f56d09b444253d5ad67bf5253"},
]
[package.dependencies]
pyyaml = "*"
requests = ">=2.30.0,<3.0"
urllib3 = ">=1.25.10,<3.0"
requests = ">=2.0,<3.0"
urllib3 = ">=1.25.10"
[package.extras]
tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "tomli", "tomli-w", "types-PyYAML", "types-requests"]
tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-localserver", "types-mock", "types-requests"]
[[package]]
name = "rfc3339-validator"
@@ -3142,16 +3137,6 @@ files = [
{file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"},
{file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"},
{file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"},
{file = "wrapt-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55"},
{file = "wrapt-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be"},
{file = "wrapt-1.14.1-cp311-cp311-win32.whl", hash = "sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204"},
{file = "wrapt-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"},

View File

@@ -1,12 +1,11 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import enum
import os
import subprocess
import sys
from typing import List
@enum.unique
@@ -56,12 +55,12 @@ def mypy() -> str:
return "poetry run mypy"
def get_commit_files() -> list[str]:
def get_commit_files() -> List[str]:
files = subprocess.check_output("git diff --cached --name-only --diff-filter=ACM".split())
return files.decode().splitlines()
def check(name: str, suffix: str, cmd: str, changed_files: list[str], no_color: bool = False):
def check(name: str, suffix: str, cmd: str, changed_files: List[str], no_color: bool = False):
print(f"Checking: {name} ", end="")
applicable_files = list(filter(lambda fname: fname.strip().endswith(suffix), changed_files))
if not applicable_files:

View File

@@ -39,7 +39,7 @@ http.workspace = true
humantime.workspace = true
humantime-serde.workspace = true
hyper0.workspace = true
hyper = { workspace = true, features = ["server", "http1", "http2"] }
hyper1 = { package = "hyper", version = "1.2", features = ["server"] }
hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] }
http-body-util = { version = "0.1" }
indexmap.workspace = true
@@ -77,7 +77,7 @@ subtle.workspace = true
thiserror.workspace = true
tikv-jemallocator.workspace = true
tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] }
tokio-postgres = { workspace = true, features = ["with-serde_json-1"] }
tokio-postgres.workspace = true
tokio-postgres-rustls.workspace = true
tokio-rustls.workspace = true
tokio-util.workspace = true
@@ -101,7 +101,7 @@ jose-jwa = "0.1.2"
jose-jwk = { version = "0.1.2", features = ["p256", "p384", "rsa"] }
signature = "2"
ecdsa = "0.16"
p256 = { version = "0.13", features = ["jwk"] }
p256 = "0.13"
rsa = "0.9"
workspace_hack.workspace = true

View File

@@ -17,8 +17,6 @@ use crate::{
RoleName,
};
use super::ComputeCredentialKeys;
// TODO(conrad): make these configurable.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
const MIN_RENEW: Duration = Duration::from_secs(30);
@@ -243,7 +241,7 @@ impl JwkCacheEntryLock {
endpoint: EndpointId,
role_name: &RoleName,
fetch: &F,
) -> Result<ComputeCredentialKeys, anyhow::Error> {
) -> Result<(), anyhow::Error> {
// JWT compact form is defined to be
// <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
// where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
@@ -302,9 +300,9 @@ impl JwkCacheEntryLock {
key => bail!("unsupported key type {key:?}"),
};
let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
.context("Provided authentication token is not a valid JWT encoding")?;
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
.context("Provided authentication token is not a valid JWT encoding")?;
tracing::debug!(?payload, "JWT signature valid with claims");
@@ -329,7 +327,7 @@ impl JwkCacheEntryLock {
);
}
Ok(ComputeCredentialKeys::JwtPayload(payloadb))
Ok(())
}
}
@@ -341,7 +339,7 @@ impl JwkCache {
role_name: &RoleName,
fetch: &F,
jwt: &str,
) -> Result<ComputeCredentialKeys, anyhow::Error> {
) -> Result<(), anyhow::Error> {
// try with just a read lock first
let key = (endpoint.clone(), role_name.clone());
let entry = self.map.get(&key).as_deref().map(Arc::clone);
@@ -573,7 +571,7 @@ mod tests {
use bytes::Bytes;
use http::Response;
use http_body_util::Full;
use hyper::service::service_fn;
use hyper1::service::service_fn;
use hyper_util::rt::TokioIo;
use rand::rngs::OsRng;
use rsa::pkcs8::DecodePrivateKey;
@@ -738,7 +736,7 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
});
let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
let server = hyper::server::conn::http1::Builder::new();
let server = hyper1::server::conn::http1::Builder::new();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
loop {

View File

@@ -175,12 +175,10 @@ impl ComputeUserInfo {
}
}
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ComputeCredentialKeys {
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
JwtPayload(Vec<u8>),
None,
}

View File

@@ -309,7 +309,7 @@ impl NodeInfo {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
ComputeCredentialKeys::None => &mut self.config,
};
}
}

View File

@@ -1,5 +1,5 @@
use anyhow::{anyhow, bail};
use hyper0::{header::CONTENT_TYPE, Body, Request, Response, StatusCode};
use hyper::{header::CONTENT_TYPE, Body, Request, Response, StatusCode};
use measured::{text::BufferedTextEncoder, MetricGroup};
use metrics::NeonMetrics;
use std::{
@@ -21,7 +21,7 @@ async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(StatusCode::OK, "")
}
fn make_router(metrics: AppMetrics) -> RouterBuilder<hyper0::Body, ApiError> {
fn make_router(metrics: AppMetrics) -> RouterBuilder<hyper::Body, ApiError> {
let state = Arc::new(Mutex::new(PrometheusHandler {
encoder: BufferedTextEncoder::new(),
metrics,
@@ -45,7 +45,7 @@ pub async fn task_main(
let service = || RouterService::new(make_router(metrics).build()?);
hyper0::Server::from_tcp(http_listener)?
hyper::Server::from_tcp(http_listener)?
.serve(service().map_err(|e| anyhow!(e))?)
.await?;

View File

@@ -9,7 +9,7 @@ use std::time::Duration;
use anyhow::bail;
use bytes::Bytes;
use http_body_util::BodyExt;
use hyper::body::Body;
use hyper1::body::Body;
use serde::de::DeserializeOwned;
pub(crate) use reqwest::{Request, Response};

View File

@@ -90,6 +90,8 @@ use tokio::task::JoinError;
use tokio_util::sync::CancellationToken;
use tracing::warn;
extern crate hyper0 as hyper;
pub mod auth;
pub mod cache;
pub mod cancellation;

View File

@@ -7,7 +7,7 @@ use crate::metrics::{
WakeupFailureKind,
};
use crate::proxy::retry::{retry_after, should_retry};
use hyper::StatusCode;
use hyper1::StatusCode;
use tracing::{error, info, warn};
use super::connect_compute::ComputeConnectBackend;

View File

@@ -3,12 +3,10 @@ use std::{io, sync::Arc, time::Duration};
use async_trait::async_trait;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use tokio::net::{lookup_host, TcpStream};
use tokio_postgres::types::ToSql;
use tracing::{debug, field::display, info};
use tracing::{field::display, info};
use crate::{
auth::{
self,
backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
check_peer_addr_is_in_list, AuthError,
},
@@ -34,12 +32,10 @@ use crate::{
use super::{
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
http_conn_pool::{self, poll_http2_client},
local_conn_pool::{self, LocalClient, LocalConnPool},
};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -116,7 +112,7 @@ impl PoolingBackend {
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
jwt: String,
) -> Result<ComputeCredentials, AuthError> {
) -> Result<(), AuthError> {
match &self.config.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
config
@@ -131,16 +127,13 @@ impl PoolingBackend {
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
})
Ok(())
}
crate::auth::Backend::ConsoleRedirect(_, ()) => Err(AuthError::auth_failed(
"JWT login over web auth proxy is not supported",
)),
crate::auth::Backend::Local(_) => {
let keys = config
config
.jwks_cache
.check_jwt(
ctx,
@@ -152,10 +145,8 @@ impl PoolingBackend {
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
Ok(ComputeCredentials {
info: user_info.clone(),
keys,
})
// todo: rewrite JWT signature with key shared somehow between local proxy and postgres
Ok(())
}
}
}
@@ -240,77 +231,6 @@ impl PoolingBackend {
)
.await
}
/// Connect to postgres over localhost.
///
/// Returns a pooled connection when one is available for `conn_info`;
/// otherwise opens a new connection, registers it with the local pool and
/// initialises the backend session state via `auth.init`.
///
/// We expect postgres to be started here, so we won't do any retries.
///
/// # Errors
///
/// Propagates pool lookup failures as well as connection and query errors
/// from `tokio_postgres`.
///
/// # Panics
///
/// Panics if called with a non-local_proxy backend.
// Both `pid` and `conn_id` have to be declared up front: `Span::record` is a
// no-op for fields that were not part of the span when it was created, so the
// `conn_id` recorded below would otherwise be dropped silently.
#[tracing::instrument(
    fields(pid = tracing::field::Empty, conn_id = tracing::field::Empty),
    skip_all
)]
pub(crate) async fn connect_to_local_postgres(
    &self,
    ctx: &RequestMonitoring,
    conn_info: ConnInfo,
) -> Result<LocalClient<tokio_postgres::Client>, HttpConnError> {
    // Fast path: reuse a cached connection for this (db, user).
    if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
        return Ok(client);
    }

    let conn_id = uuid::Uuid::new_v4();
    tracing::Span::current().record("conn_id", display(conn_id));
    info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");

    let mut node_info = match &self.config.auth_backend {
        auth::Backend::ControlPlane(_, ()) | auth::Backend::ConsoleRedirect(_, ()) => {
            unreachable!("only local_proxy can connect to local postgres")
        }
        auth::Backend::Local(local) => local.node_info.clone(),
    };

    let config = node_info
        .config
        .user(&conn_info.user_info.user)
        .dbname(&conn_info.dbname);

    // Time spent connecting is attributed to the compute, not to the proxy.
    let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
    let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
    drop(pause);

    tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));

    let handle = local_conn_pool::poll_client(
        self.local_pool.clone(),
        ctx,
        conn_info,
        client,
        connection,
        conn_id,
        node_info.aux.clone(),
    );

    // The process id of the client doubles as the key id for the session key.
    let kid = handle.get_client().get_process_id() as i64;
    let jwk = p256::PublicKey::from(handle.key().verifying_key()).to_jwk();

    debug!(kid, ?jwk, "setting up backend session state");

    // initiates the auth session
    handle
        .get_client()
        .query(
            "select auth.init($1, $2);",
            &[
                &kid as &(dyn ToSql + Sync),
                &tokio_postgres::types::Json(jwk),
            ],
        )
        .await?;

    info!(?kid, "backend session state init");

    Ok(handle)
}
}
#[derive(Debug, thiserror::Error)]
@@ -321,8 +241,6 @@ pub(crate) enum HttpConnError {
PostgresConnectionError(#[from] tokio_postgres::Error),
#[error("could not connection to local-proxy in compute")]
LocalProxyConnectionError(#[from] LocalProxyConnError),
#[error("could not parse JWT payload")]
JwtPayloadError(serde_json::Error),
#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
@@ -339,7 +257,7 @@ pub(crate) enum LocalProxyConnError {
#[error("error with connection to local-proxy")]
Io(#[source] std::io::Error),
#[error("could not establish h2 connection")]
H2(#[from] hyper::Error),
H2(#[from] hyper1::Error),
}
impl ReportableError for HttpConnError {
@@ -348,7 +266,6 @@ impl ReportableError for HttpConnError {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::JwtPayloadError(_) => ErrorKind::User,
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
HttpConnError::AuthError(a) => a.get_error_kind(),
HttpConnError::WakeCompute(w) => w.get_error_kind(),
@@ -363,7 +280,6 @@ impl UserFacingError for HttpConnError {
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
HttpConnError::JwtPayloadError(p) => p.to_string(),
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
HttpConnError::AuthError(c) => c.to_string_client(),
HttpConnError::WakeCompute(c) => c.to_string_client(),
@@ -380,7 +296,6 @@ impl CouldRetry for HttpConnError {
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::JwtPayloadError(_) => false,
HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false,
HttpConnError::WakeCompute(_) => false,
@@ -566,7 +481,7 @@ async fn connect_http2(
};
};
let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_while_idle(true)

View File

@@ -1,5 +1,5 @@
use dashmap::DashMap;
use hyper::client::conn::http2;
use hyper1::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
@@ -18,9 +18,9 @@ use tracing::{info, info_span, Instrument};
use super::conn_pool::ConnInfo;
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
pub(crate) type Connect =
http2::Connection<TokioIo<TcpStream>, hyper::body::Incoming, TokioExecutor>;
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
#[derive(Clone)]
struct ConnPoolEntry {

View File

@@ -11,7 +11,7 @@ use serde::Serialize;
use utils::http::error::ApiError;
/// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper::Error>> {
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause
@@ -67,12 +67,12 @@ impl HttpErrorBody {
fn response_from_msg_and_status(
msg: String,
status: StatusCode,
) -> Response<BoxBody<Bytes, hyper::Error>> {
) -> Response<BoxBody<Bytes, hyper1::Error>> {
HttpErrorBody { msg }.to_response(status)
}
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper::Error>> {
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
@@ -90,7 +90,7 @@ impl HttpErrorBody {
pub(crate) fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let json = serde_json::to_string(&data)
.context("Failed to serialize JSON response")
.map_err(ApiError::InternalServerError)?;

View File

@@ -1,544 +0,0 @@
use futures::{future::poll_fn, Future};
use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
use p256::ecdsa::{Signature, SigningKey};
use parking_lot::RwLock;
use rand::rngs::OsRng;
use serde_json::Value;
use signature::Signer;
use std::task::{ready, Poll};
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::types::ToSql;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_util::sync::CancellationToken;
use typed_json::json;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::Metrics;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, DbName, RoleName};
use tracing::{debug, error, warn, Span};
use tracing::{info, info_span, Instrument};
use super::backend::HttpConnError;
use super::conn_pool::{ClientInnerExt, ConnInfo};
/// One pooled connection plus bookkeeping metadata.
struct ConnPoolEntry<C: ClientInnerExt> {
// The pooled client together with its per-connection state.
conn: ClientInner<C>,
// Set to `std::time::Instant::now()` when the entry is pushed into the pool.
// Nothing in this file reads it back, hence the leading underscore.
_last_access: std::time::Instant,
}
// /// key id for the pg_session_jwt state
// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1);
/// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
/// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
// Pooled connections, keyed by (database, role).
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
// Number of connections currently held across all entries of `pools`.
total_conns: usize,
// Upper bound on `total_conns`; initialised from
// `pool_options.max_conns_per_endpoint` in `LocalConnPool::new`.
max_conns: usize,
// Second upper bound checked in `put`; initialised from
// `pool_options.max_total_conns` in `LocalConnPool::new`.
global_pool_size_max_conns: usize,
}
impl<C: ClientInnerExt> EndpointConnPool<C> {
    /// Takes one pooled connection for `db_user` out of the pool, if any.
    ///
    /// Delegates to [`DbUserConnPool::get_conn_entry`], which also keeps
    /// `total_conns` and the open-connections metric in sync.
    fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
        let per_db_user = self.pools.get_mut(&db_user)?;
        per_db_user.get_conn_entry(&mut self.total_conns)
    }

    /// Drops every pooled entry of `db_user` whose connection id is `conn_id`.
    ///
    /// Returns `true` when at least one entry was removed. Connections that
    /// are currently checked out are not in the pool and are left untouched.
    fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
        let per_db_user = match self.pools.get_mut(&db_user) {
            Some(per_db_user) => per_db_user,
            None => return false,
        };

        let len_before = per_db_user.conns.len();
        per_db_user
            .conns
            .retain(|entry| entry.conn.conn_id != conn_id);
        let removed = len_before - per_db_user.conns.len();
        let any_removed = removed > 0;

        if any_removed {
            Metrics::get()
                .proxy
                .http_pool_opened_connections
                .get_metric()
                .dec_by(removed as i64);
        }
        self.total_conns -= removed;

        any_removed
    }

    /// Hands `client` back to the pool.
    ///
    /// The connection is thrown away instead when it is already closed or
    /// when either size limit has been reached.
    fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
        let conn_id = client.conn_id;

        if client.is_closed() {
            info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because connection is closed");
            return;
        }

        let global_max_conn = pool.read().global_pool_size_max_conns;
        if pool.read().total_conns >= global_max_conn {
            info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full");
            return;
        }

        // `stored` carries the per-(db, user) size when the connection was
        // accepted, and `None` when the per-endpoint limit rejected it.
        let (stored, total_conns) = {
            let mut guard = pool.write();

            let stored = if guard.total_conns < guard.max_conns {
                let slot = guard.pools.entry(conn_info.db_and_user()).or_default();
                slot.conns.push(ConnPoolEntry {
                    conn: client,
                    _last_access: std::time::Instant::now(),
                });
                let per_db_size = slot.conns.len();

                guard.total_conns += 1;
                Metrics::get()
                    .proxy
                    .http_pool_opened_connections
                    .get_metric()
                    .inc();

                Some(per_db_size)
            } else {
                None
            };

            (stored, guard.total_conns)
        };

        // The write lock is released by now; only logging is left.
        match stored {
            Some(per_db_size) => {
                info!(%conn_id, "local_pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
            }
            None => {
                info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
            }
        }
    }
}
impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
    /// Takes the connections still held by this pool out of the
    /// open-connections metric when the pool itself goes away.
    fn drop(&mut self) {
        let remaining = self.total_conns;
        if remaining == 0 {
            return;
        }

        Metrics::get()
            .proxy
            .http_pool_opened_connections
            .get_metric()
            .dec_by(remaining as i64);
    }
}
/// Pooled connections belonging to a single (database, role) pair.
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
// Used as a stack: `put` pushes to the end and `get_conn_entry` pops from the end.
conns: Vec<ConnPoolEntry<C>>,
}
// Hand-written rather than derived: `#[derive(Default)]` would add a
// `C: Default` bound, which this impl does not require.
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
/// Creates an empty pool.
fn default() -> Self {
Self { conns: Vec::new() }
}
}
impl<C: ClientInnerExt> DbUserConnPool<C> {
    /// Drops every pooled connection that reports itself as closed.
    ///
    /// Subtracts the number of dropped entries from `conns` and returns it.
    fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
        let len_before = self.conns.len();
        self.conns.retain(|entry| !entry.conn.is_closed());
        let removed = len_before - self.conns.len();

        *conns -= removed;
        removed
    }

    /// Pops the most recently returned connection that is still open.
    ///
    /// `conns` and the open-connections metric are reduced by the number of
    /// closed entries pruned plus one for the entry handed out (if any).
    fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry<C>> {
        let pruned = self.clear_closed_clients(conns);

        let entry = self.conns.pop();
        let handed_out = usize::from(entry.is_some());
        *conns -= handed_out;

        Metrics::get()
            .proxy
            .http_pool_opened_connections
            .get_metric()
            .dec_by((pruned + handed_out) as i64);

        entry
    }
}
/// Connection pool used for connections to the local postgres.
///
/// Wraps a single `EndpointConnPool` behind a lock.
pub(crate) struct LocalConnPool<C: ClientInnerExt> {
// The one and only endpoint pool; guarded by a `parking_lot::RwLock`.
global_pool: RwLock<EndpointConnPool<C>>,
// Source of the pool limits and of the idle timeout.
config: &'static crate::config::HttpConfig,
}
impl<C: ClientInnerExt> LocalConnPool<C> {
    /// Builds an empty pool whose limits are taken from `config.pool_options`.
    pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
        let opts = &config.pool_options;
        let global_pool = RwLock::new(EndpointConnPool {
            pools: HashMap::new(),
            total_conns: 0,
            max_conns: opts.max_conns_per_endpoint,
            global_pool_size_max_conns: opts.max_total_conns,
        });

        Arc::new(Self {
            global_pool,
            config,
        })
    }

    /// Idle timeout configured for pooled connections.
    pub(crate) fn get_idle_timeout(&self) -> Duration {
        self.config.pool_options.idle_timeout
    }

    // pub(crate) fn shutdown(&self) {
    //     let mut pool = self.global_pool.write();
    //     pool.pools.clear();
    //     pool.total_conns = 0;
    // }

    /// Looks up a cached connection for `conn_info`.
    ///
    /// Returns `Ok(None)` when nothing usable is cached, in which case the
    /// caller is expected to open a new connection. On a hit the connection's
    /// background task is told about the new session id and the request is
    /// marked as a pool hit.
    pub(crate) fn get(
        self: &Arc<Self>,
        ctx: &RequestMonitoring,
        conn_info: &ConnInfo,
    ) -> Result<Option<LocalClient<C>>, HttpConnError> {
        let cached = self
            .global_pool
            .write()
            .get_conn_entry(conn_info.db_and_user())
            .map(|entry| entry.conn);

        // ok return cached connection if found and establish a new one otherwise
        let client = match cached {
            Some(client) => client,
            None => return Ok(None),
        };

        if client.is_closed() {
            info!("local_pool: cached connection '{conn_info}' is closed, opening a new one");
            return Ok(None);
        }

        let span = tracing::Span::current();
        span.record("conn_id", tracing::field::display(client.conn_id));
        span.record(
            "pid",
            tracing::field::display(client.inner.get_process_id()),
        );
        info!(
            cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
            "local_pool: reusing connection '{conn_info}'"
        );

        client.session.send(ctx.session_id())?;
        ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
        ctx.success();

        Ok(Some(LocalClient::new(
            client,
            conn_info.clone(),
            Arc::downgrade(self),
        )))
    }
}
/// Wraps a freshly opened postgres connection into a `LocalClient` and spawns
/// the background task that drives `connection`.
///
/// The spawned task finishes when any of the following happens:
/// - the cancellation token is cancelled (done by `ClientInner`'s `Drop`),
/// - the session watch channel's sender is gone,
/// - the connection yields an error or is closed (the client is then also
///   removed from the pool).
///
/// While running, the task logs notices/notifications and, whenever the idle
/// timer fires, tries to remove the client from the pool.
///
/// The returned client carries a freshly generated random signing key and a
/// `jti` counter starting at 0.
pub(crate) fn poll_client(
global_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
client: tokio_postgres::Client,
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> LocalClient<tokio_postgres::Client> {
// Moved into the task below, so the guard lives as long as the task does.
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let mut session_id = ctx.session_id();
// `tx` ends up in the returned client; the task watches `rx` for session changes.
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
// Root span (`parent: None`): not nested under the calling request's span.
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
// Only weak references are kept, so neither the task nor the client keeps the pool alive.
let pool = Arc::downgrade(&global_pool);
let pool_clone = pool.clone();
let db_user = conn_info.db_and_user();
let idle = global_pool.get_idle_timeout();
let cancel = CancellationToken::new();
let cancelled = cancel.clone().cancelled_owned();
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let mut idle_timeout = pin!(tokio::time::sleep(idle));
let mut cancelled = pin!(cancelled);
poll_fn(move |cx| {
// Checked first on every poll: the client side has been dropped.
if cancelled.as_mut().poll(cx).is_ready() {
info!("connection dropped");
return Poll::Ready(())
}
match rx.has_changed() {
Ok(true) => {
// A new session picked up this connection: restart the idle timer.
session_id = *rx.borrow_and_update();
info!(%session_id, "changed session");
idle_timeout.as_mut().reset(Instant::now() + idle);
}
Err(_) => {
// The sender half is gone.
info!("connection dropped");
return Poll::Ready(())
}
_ => {}
}
// idle connection timeout (duration comes from `get_idle_timeout()`)
if idle_timeout.as_mut().poll(cx).is_ready() {
idle_timeout.as_mut().reset(Instant::now() + idle);
info!("connection idle");
if let Some(pool) = pool.clone().upgrade() {
// remove client from pool - should close the connection if it's idle.
// does nothing if the client is currently checked-out and in-use
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
info!("idle connection removed");
}
}
}
// Drain connection messages; `ready!` returns `Poll::Pending` from the
// closure when no message is available yet.
loop {
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Ok(_)) => {
warn!(%session_id, "unknown message");
}
Some(Err(e)) => {
error!(%session_id, "connection error: {}", e);
break
}
None => {
info!("connection closed");
break
}
}
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;
}
.instrument(span));
// Per-connection signing key, generated from the OS RNG.
let key = SigningKey::random(&mut OsRng);
let inner = ClientInner {
inner: client,
session: tx,
cancel,
aux,
conn_id,
key,
jti: 0,
};
LocalClient::new(inner, conn_info, pool_clone)
}
/// A client plus the per-connection state that travels with it through the pool.
struct ClientInner<C: ClientInnerExt> {
// The underlying client.
inner: C,
// Sender half of the watch channel created in `poll_client`; used to tell
// the connection task which session id currently uses the connection.
session: tokio::sync::watch::Sender<uuid::Uuid>,
// Cancelled on drop so that the connection task stops.
cancel: CancellationToken,
// Endpoint/branch ids used when registering usage metrics.
aux: MetricsAuxInfo,
// Identifier used for logging and for removing the client from the pool.
conn_id: uuid::Uuid,
// needed for pg_session_jwt state
// `key` signs the JWTs built in `set_jwt_session`.
key: SigningKey,
// Incremented for every JWT issued and written into its `jti` claim.
jti: u64,
}
impl<C: ClientInnerExt> Drop for ClientInner<C> {
/// Cancels the token observed by the connection task spawned in `poll_client`.
fn drop(&mut self) {
// on client drop, tell the conn to shut down
self.cancel.cancel();
}
}
impl<C: ClientInnerExt> ClientInner<C> {
/// Returns whatever the wrapped client reports via its own `is_closed()`.
pub(crate) fn is_closed(&self) -> bool {
self.inner.is_closed()
}
}
impl<C: ClientInnerExt> LocalClient<C> {
    /// Registers (or looks up) the usage-metrics counter for the endpoint and
    /// branch this connection belongs to.
    ///
    /// # Panics
    ///
    /// Panics if the inner client has already been taken out, which only
    /// happens while the client is being dropped.
    pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
        // Same invariant and panic message as the other accessors of this
        // type, instead of a bare `unwrap()` without context.
        let aux = &self
            .inner
            .as_ref()
            .expect("client inner should not be removed")
            .aux;
        USAGE_METRICS.register(Ids {
            endpoint_id: aux.endpoint_id,
            branch_id: aux.branch_id,
        })
    }
}
/// Handle to a connection checked out of the pool (or freshly opened).
///
/// On drop the connection is handed back to the pool if the pool is still
/// reachable through `pool`.
pub(crate) struct LocalClient<C: ClientInnerExt> {
// Span that was current when the handle was created; entered again when the
// connection is returned to the pool.
span: Span,
// `Some` for the whole lifetime of the handle; taken out in `do_drop`.
inner: Option<ClientInner<C>>,
conn_info: ConnInfo,
// Replaced with an empty `Weak` by `Discard` to prevent the return to the pool.
pool: Weak<LocalConnPool<C>>,
}
/// Borrowed view onto a `LocalClient` that allows opting the connection out
/// of being returned to the pool.
pub(crate) struct Discard<'a, C: ClientInnerExt> {
// Only used for log messages.
conn_info: &'a ConnInfo,
// The client's pool reference; resetting it detaches the client from the pool.
pool: &'a mut Weak<LocalConnPool<C>>,
}
impl<C: ClientInnerExt> LocalClient<C> {
    /// Wraps `inner` into a handle that remembers the currently active span
    /// and the pool it should be returned to.
    pub(self) fn new(
        inner: ClientInner<C>,
        conn_info: ConnInfo,
        pool: Weak<LocalConnPool<C>>,
    ) -> Self {
        let span = Span::current();
        Self {
            span,
            inner: Some(inner),
            conn_info,
            pool,
        }
    }

    /// Gives mutable access to the underlying client together with a
    /// [`Discard`] handle that can detach the connection from the pool.
    ///
    /// # Panics
    ///
    /// Panics if the inner client has already been taken out.
    pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
        let client = &mut self
            .inner
            .as_mut()
            .expect("client inner should not be removed")
            .inner;
        let discard = Discard {
            conn_info: &self.conn_info,
            pool: &mut self.pool,
        };
        (client, discard)
    }

    /// The signing key that belongs to this connection.
    ///
    /// # Panics
    ///
    /// Panics if the inner client has already been taken out.
    pub(crate) fn key(&self) -> &SigningKey {
        &self
            .inner
            .as_ref()
            .expect("client inner should not be removed")
            .key
    }
}
impl LocalClient<tokio_postgres::Client> {
/// Re-initialises the session state of this connection for the given JWT payload.
///
/// Builds a new JWT from `payload` (a JSON object, with a `jti` claim
/// inserted/overwritten), signs it with this connection's key, runs
/// `discard all` and then passes the token to `auth.jwt_session_init`.
///
/// # Errors
///
/// Returns `HttpConnError::JwtPayloadError` if `payload` is not a JSON object,
/// or the converted `tokio_postgres` error if one of the two statements fails.
///
/// # Panics
///
/// Panics if the inner client has already been taken out.
pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
let inner = self
.inner
.as_mut()
.expect("client inner should not be removed");
// Incremented before parsing, so a failed call still advances the counter.
inner.jti += 1;
// The client's process id is used as the key id in the JWT header.
let kid = inner.inner.get_process_id();
let header = json!({"kid":kid}).to_string();
let mut payload = serde_json::from_slice::<serde_json::Map<String, Value>>(payload)
.map_err(HttpConnError::JwtPayloadError)?;
// Any `jti` already present in the incoming payload is replaced.
payload.insert("jti".to_string(), Value::Number(inner.jti.into()));
let payload = Value::Object(payload).to_string();
debug!(
kid,
jti = inner.jti,
?header,
?payload,
"signing new ephemeral JWT"
);
let token = sign_jwt(&inner.key, header, payload);
// initiates the auth session
// `discard all` runs first, before the new session is initialised.
inner.inner.simple_query("discard all").await?;
inner
.inner
.query(
"select auth.jwt_session_init($1)",
&[&token as &(dyn ToSql + Sync)],
)
.await?;
info!(kid, jti = inner.jti, "user session state init");
Ok(())
}
}
/// Assembles a JWT in compact form: `B64(header).B64(payload).B64(signature)`.
///
/// All three parts use unpadded URL-safe base64; the signature is computed
/// with `sk` over the first two parts joined by a dot.
fn sign_jwt(sk: &SigningKey, header: String, payload: String) -> String {
    let signing_input = format!(
        "{}.{}",
        Base64UrlUnpadded::encode_string(header.as_bytes()),
        Base64UrlUnpadded::encode_string(payload.as_bytes()),
    );

    let signature: Signature = sk.sign(signing_input.as_bytes());
    let encoded_signature = Base64UrlUnpadded::encode_string(&signature.to_bytes());

    format!("{signing_input}.{encoded_signature}")
}
impl<C: ClientInnerExt> Discard<'_, C> {
    /// Detaches the connection from the pool unless `status` is `Idle`.
    ///
    /// An idle connection keeps its pool reference and is left untouched.
    pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
        if status == ReadyForQueryStatus::Idle {
            return;
        }

        // Taking the `Weak` leaves an empty one behind, so the client can no
        // longer find its way back into the pool.
        let detached = std::mem::take(self.pool);
        if detached.strong_count() > 0 {
            let conn_info = &self.conn_info;
            info!(
                "local_pool: throwing away connection '{conn_info}' because connection is not idle"
            );
        }
    }

    /// Unconditionally detaches the connection from the pool.
    pub(crate) fn discard(&mut self) {
        let detached = std::mem::take(self.pool);
        if detached.strong_count() > 0 {
            let conn_info = &self.conn_info;
            info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
        }
    }
}
impl<C: ClientInnerExt> LocalClient<C> {
    /// Shared access to the underlying client.
    ///
    /// # Panics
    ///
    /// Panics if the inner client has already been taken out.
    pub fn get_client(&self) -> &C {
        let state = self
            .inner
            .as_ref()
            .expect("client inner should not be removed");
        &state.inner
    }

    /// Takes the inner client out and, if the pool is still alive and the
    /// client has not been detached from it, returns a closure that puts the
    /// connection back into the pool.
    ///
    /// Returns `None` otherwise; the connection is then simply dropped.
    fn do_drop(&mut self) -> Option<impl FnOnce()> {
        let client = self
            .inner
            .take()
            .expect("client inner should not be removed");

        let conn_pool = std::mem::take(&mut self.pool).upgrade()?;
        let conn_info = self.conn_info.clone();
        let current_span = self.span.clone();

        // return connection to the pool
        Some(move || {
            let _span = current_span.enter();
            EndpointConnPool::put(&conn_pool.global_pool, &conn_info, client);
        })
    }
}
impl<C: ClientInnerExt> Drop for LocalClient<C> {
/// Returns the connection to the pool, if `do_drop` decides it should be.
fn drop(&mut self) {
if let Some(drop) = self.do_drop() {
// The closure takes the pool's lock in `EndpointConnPool::put`; it is run via
// `spawn_blocking` rather than inline.
// NOTE(review): `spawn_blocking` presumably requires a tokio runtime context —
// confirm a `LocalClient` is never dropped outside of one.
tokio::task::spawn_blocking(drop);
}
}
}

View File

@@ -8,7 +8,6 @@ mod conn_pool;
mod http_conn_pool;
mod http_util;
mod json;
mod local_conn_pool;
mod sql_over_http;
mod websocket;
@@ -23,7 +22,7 @@ use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use hyper::body::Incoming;
use hyper1::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::rngs::StdRng;
@@ -64,7 +63,6 @@ pub async fn task_main(
info!("websocket server has shut down");
}
let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config);
let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
{
let conn_pool = Arc::clone(&conn_pool);
@@ -107,7 +105,6 @@ pub async fn task_main(
let backend = Arc::new(PoolingBackend {
http_conn_pool: Arc::clone(&http_conn_pool),
local_pool,
pool: Arc::clone(&conn_pool),
config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
@@ -305,7 +302,7 @@ async fn connection_handler(
let server = Builder::new(TokioExecutor::new());
let conn = server.serve_connection_with_upgrades(
hyper_util::rt::TokioIo::new(conn),
hyper::service::service_fn(move |req: hyper::Request<Incoming>| {
hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
// First HTTP request shares the same session ID
let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
@@ -358,7 +355,7 @@ async fn connection_handler(
#[allow(clippy::too_many_arguments)]
async fn request_handler(
mut request: hyper::Request<Incoming>,
mut request: hyper1::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
@@ -368,7 +365,7 @@ async fn request_handler(
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let host = request
.headers()
.get("host")

View File

@@ -12,14 +12,14 @@ use http::Method;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper::body::Body;
use hyper::body::Incoming;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{HeaderMap, Request};
use hyper1::body::Body;
use hyper1::body::Incoming;
use hyper1::header;
use hyper1::http::HeaderName;
use hyper1::http::HeaderValue;
use hyper1::Response;
use hyper1::StatusCode;
use hyper1::{HeaderMap, Request};
use pq_proto::StartupMessageParamsBuilder;
use serde::Serialize;
use serde_json::Value;
@@ -40,7 +40,7 @@ use url::Url;
use urlencoding;
use utils::http::error::ApiError;
use crate::auth::backend::ComputeCredentialKeys;
use crate::auth::backend::ComputeCredentials;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
@@ -56,22 +56,20 @@ use crate::metrics::Metrics;
use crate::proxy::run_until_cancelled;
use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
use crate::usage_metrics::MetricCounter;
use crate::usage_metrics::MetricCounterRecorder;
use crate::DbName;
use crate::RoleName;
use super::backend::LocalProxyConnError;
use super::backend::PoolingBackend;
use super::conn_pool;
use super::conn_pool::AuthData;
use super::conn_pool::Client;
use super::conn_pool::ConnInfo;
use super::conn_pool::ConnInfoWithAuth;
use super::http_util::json_response;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::json::JsonConversionError;
use super::local_conn_pool;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -274,7 +272,7 @@ pub(crate) async fn handle(
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let mut response = match result {
@@ -437,7 +435,7 @@ impl UserFacingError for SqlOverHttpError {
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
Read(#[from] hyper1::Error),
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
@@ -478,7 +476,7 @@ struct HttpHeaders {
}
impl HttpHeaders {
fn try_parse(headers: &hyper::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
fn try_parse(headers: &hyper1::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
// Determine the output options. Default behaviour is 'false'. Anything that is not
// strictly 'true' is assumed to be false.
let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
@@ -531,7 +529,7 @@ async fn handle_inner(
ctx: &RequestMonitoring,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
let _requeset_gauge = Metrics::get()
.proxy
.connection_requests
@@ -579,7 +577,7 @@ async fn handle_db_inner(
conn_info: ConnInfo,
auth: AuthData,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
//
// Determine the destination and connection params
//
@@ -622,9 +620,6 @@ async fn handle_db_inner(
let authenticate_and_connect = Box::pin(
async {
let is_local_proxy =
matches!(backend.config.auth_backend, crate::auth::Backend::Local(_));
let keys = match auth {
AuthData::Password(pw) => {
backend
@@ -644,24 +639,18 @@ async fn handle_db_inner(
&conn_info.user_info,
jwt,
)
.await?
}
};
let client = match keys.keys {
ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => {
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
client.set_jwt_session(&payload).await?;
Client::Local(client)
}
_ => {
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
Client::Remote(client)
ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
}
}
};
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
ctx.success();
@@ -755,7 +744,7 @@ async fn handle_auth_broker_inner(
conn_info: ConnInfo,
jwt: String,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
backend
.authenticate_with_jwt(
ctx,
@@ -802,7 +791,7 @@ impl QueryData {
self,
config: &'static ProxyConfig,
cancel: CancellationToken,
client: &mut Client,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let (inner, mut discard) = client.inner();
@@ -876,7 +865,7 @@ impl BatchQueryData {
self,
config: &'static ProxyConfig,
cancel: CancellationToken,
client: &mut Client,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
info!("starting transaction");
@@ -1069,50 +1058,3 @@ async fn query_to_json<T: GenericClient>(
Ok((ready, results))
}
/// A database client handed to the SQL-over-HTTP query path.
///
/// Wraps either a pooled connection from `conn_pool` or a pooled connection
/// from `local_conn_pool`, so callers can use one type regardless of which
/// pool produced the connection.
enum Client {
// Connection obtained from the regular connection pool.
Remote(conn_pool::Client<tokio_postgres::Client>),
// Connection obtained from the local connection pool.
Local(local_conn_pool::LocalClient<tokio_postgres::Client>),
}
/// Discard handle paired with the connection returned by `Client::inner`.
///
/// Mirrors the two `Client` variants and forwards to the matching pool's own
/// `Discard` type.
enum Discard<'a> {
// Discard handle for a `conn_pool` connection.
Remote(conn_pool::Discard<'a, tokio_postgres::Client>),
// Discard handle for a `local_conn_pool` connection.
Local(local_conn_pool::Discard<'a, tokio_postgres::Client>),
}
impl Client {
    /// Returns the usage-metrics counter of whichever pooled client is wrapped.
    fn metrics(&self) -> Arc<MetricCounter> {
        match self {
            Self::Remote(remote) => remote.metrics(),
            Self::Local(local) => local.metrics(),
        }
    }

    /// Borrows the underlying `tokio_postgres::Client` together with the
    /// discard handle of the pool it came from, wrapped in the matching
    /// [`Discard`] variant.
    fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
        match self {
            Self::Remote(remote) => {
                let (pg_client, guard) = remote.inner();
                (pg_client, Discard::Remote(guard))
            }
            Self::Local(local) => {
                let (pg_client, guard) = local.inner();
                (pg_client, Discard::Local(guard))
            }
        }
    }
}
impl Discard<'_> {
fn check_idle(&mut self, status: ReadyForQueryStatus) {
match self {
Discard::Remote(discard) => discard.check_idle(status),
Discard::Local(discard) => discard.check_idle(status),
}
}
fn discard(&mut self) {
match self {
Discard::Remote(discard) => discard.discard(),
Discard::Local(discard) => discard.discard(),
}
}
}

View File

@@ -12,7 +12,7 @@ use anyhow::Context as _;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
use hyper::upgrade::OnUpgrade;
use hyper1::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;

View File

@@ -485,51 +485,49 @@ async fn upload_events_chunk(
#[cfg(test)]
mod tests {
use super::*;
use std::{
net::TcpListener,
sync::{Arc, Mutex},
};
use crate::{http, BranchId, EndpointId};
use anyhow::Error;
use chrono::Utc;
use consumption_metrics::{Event, EventChunk};
use http_body_util::BodyExt;
use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Response};
use hyper_util::rt::TokioIo;
use std::sync::{Arc, Mutex};
use tokio::net::TcpListener;
use hyper::{
service::{make_service_fn, service_fn},
Body, Response,
};
use url::Url;
use super::*;
use crate::{http, BranchId, EndpointId};
#[tokio::test]
async fn metrics() {
type Report = EventChunk<'static, Event<Ids, String>>;
let reports: Arc<Mutex<Vec<Report>>> = Arc::default();
let listener = TcpListener::bind("0.0.0.0:0").unwrap();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn({
let reports = reports.clone();
async move {
loop {
if let Ok((stream, _addr)) = listener.accept().await {
let reports = Arc::new(Mutex::new(vec![]));
let reports2 = reports.clone();
let server = hyper::server::Server::from_tcp(listener)
.unwrap()
.serve(make_service_fn(move |_| {
let reports = reports.clone();
async move {
Ok::<_, Error>(service_fn(move |req| {
let reports = reports.clone();
http1::Builder::new()
.serve_connection(
TokioIo::new(stream),
service_fn(move |req: Request<Incoming>| {
let reports = reports.clone();
async move {
let bytes = req.into_body().collect().await?.to_bytes();
let events = serde_json::from_slice(&bytes)?;
reports.lock().unwrap().push(events);
Ok::<_, Error>(Response::new(String::new()))
}
}),
)
.await
.unwrap();
}
async move {
let bytes = hyper::body::to_bytes(req.into_body()).await?;
let events: EventChunk<'static, Event<Ids, String>> =
serde_json::from_slice(&bytes)?;
reports.lock().unwrap().push(events);
Ok::<_, Error>(Response::new(Body::from(vec![])))
}
}))
}
}
});
}));
let addr = server.local_addr();
tokio::spawn(server);
let metrics = Metrics::default();
let client = http::new_client();
@@ -538,7 +536,7 @@ mod tests {
// no counters have been registered
collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await;
let r = std::mem::take(&mut *reports.lock().unwrap());
let r = std::mem::take(&mut *reports2.lock().unwrap());
assert!(r.is_empty());
// register a new counter
@@ -550,7 +548,7 @@ mod tests {
// the counter should be observed despite 0 egress
collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await;
let r = std::mem::take(&mut *reports.lock().unwrap());
let r = std::mem::take(&mut *reports2.lock().unwrap());
assert_eq!(r.len(), 1);
assert_eq!(r[0].events.len(), 1);
assert_eq!(r[0].events[0].value, 0);
@@ -560,7 +558,7 @@ mod tests {
// egress should be observed
collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await;
let r = std::mem::take(&mut *reports.lock().unwrap());
let r = std::mem::take(&mut *reports2.lock().unwrap());
assert_eq!(r.len(), 1);
assert_eq!(r[0].events.len(), 1);
assert_eq!(r[0].events[0].value, 1);
@@ -570,7 +568,7 @@ mod tests {
// we do not observe the counter
collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await;
let r = std::mem::take(&mut *reports.lock().unwrap());
let r = std::mem::take(&mut *reports2.lock().unwrap());
assert!(r.is_empty());
// counter is unregistered

View File

@@ -97,8 +97,5 @@ select = [
"I", # isort
"W", # pycodestyle
"B", # bugbear
"UP", # pyupgrade
"UP032", # f-string
]
[tool.ruff.lint.pyupgrade]
keep-runtime-typing = true # Remove this stanza when we require Python 3.10

View File

@@ -12,8 +12,8 @@ use metrics::{
core::{AtomicU64, Collector, Desc, GenericCounter, GenericGaugeVec, Opts},
proto::MetricFamily,
register_histogram_vec, register_int_counter, register_int_counter_pair,
register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge, Gauge,
HistogramVec, IntCounter, IntCounterPair, IntCounterPairVec, IntCounterVec, IntGaugeVec,
register_int_counter_pair_vec, register_int_counter_vec, Gauge, HistogramVec, IntCounter,
IntCounterPair, IntCounterPairVec, IntCounterVec, IntGaugeVec,
};
use once_cell::sync::Lazy;
@@ -231,14 +231,6 @@ pub(crate) static EVICTION_EVENTS_COMPLETED: Lazy<IntCounterVec> = Lazy::new(||
.expect("Failed to register metric")
});
/// Integer gauge exported as `safekeeper_evicted_timelines`: the number of
/// timelines that are currently evicted.
///
/// Registered lazily on first access; registration failure panics via
/// `expect`.
pub static NUM_EVICTED_TIMELINES: Lazy<IntGauge> = Lazy::new(|| {
register_int_gauge!(
"safekeeper_evicted_timelines",
"Number of currently evicted timelines"
)
.expect("Failed to register metric")
});
pub const LABEL_UNKNOWN: &str = "unknown";
/// Labels for traffic metrics.

View File

@@ -631,19 +631,13 @@ impl Timeline {
return Err(e);
}
self.bootstrap(
shared_state,
conf,
broker_active_set,
partial_backup_rate_limiter,
);
self.bootstrap(conf, broker_active_set, partial_backup_rate_limiter);
Ok(())
}
/// Bootstrap new or existing timeline starting background tasks.
pub fn bootstrap(
self: &Arc<Timeline>,
_shared_state: &mut WriteGuardSharedState<'_>,
conf: &SafeKeeperConf,
broker_active_set: Arc<TimelinesSet>,
partial_backup_rate_limiter: RateLimiter,

View File

@@ -15,9 +15,7 @@ use tracing::{debug, info, instrument, warn};
use utils::crashsafe::durable_rename;
use crate::{
metrics::{
EvictionEvent, EVICTION_EVENTS_COMPLETED, EVICTION_EVENTS_STARTED, NUM_EVICTED_TIMELINES,
},
metrics::{EvictionEvent, EVICTION_EVENTS_COMPLETED, EVICTION_EVENTS_STARTED},
rate_limit::rand_duration,
timeline_manager::{Manager, StateSnapshot},
wal_backup,
@@ -95,7 +93,6 @@ impl Manager {
}
info!("successfully evicted timeline");
NUM_EVICTED_TIMELINES.inc();
}
/// Attempt to restore evicted timeline from remote storage; it must be
@@ -131,7 +128,6 @@ impl Manager {
tokio::time::Instant::now() + rand_duration(&self.conf.eviction_min_resident);
info!("successfully restored evicted timeline");
NUM_EVICTED_TIMELINES.dec();
}
}

View File

@@ -25,10 +25,7 @@ use utils::lsn::Lsn;
use crate::{
control_file::{FileStorage, Storage},
metrics::{
MANAGER_ACTIVE_CHANGES, MANAGER_ITERATIONS_TOTAL, MISC_OPERATION_SECONDS,
NUM_EVICTED_TIMELINES,
},
metrics::{MANAGER_ACTIVE_CHANGES, MANAGER_ITERATIONS_TOTAL, MISC_OPERATION_SECONDS},
rate_limit::{rand_duration, RateLimiter},
recovery::recovery_main,
remove_wal::calc_horizon_lsn,
@@ -254,11 +251,6 @@ pub async fn main_task(
mgr.recovery_task = Some(tokio::spawn(recovery_main(tli, mgr.conf.clone())));
}
// If timeline is evicted, reflect that in the metric.
if mgr.is_offloaded {
NUM_EVICTED_TIMELINES.inc();
}
let last_state = 'outer: loop {
MANAGER_ITERATIONS_TOTAL.inc();
@@ -375,11 +367,6 @@ pub async fn main_task(
mgr.update_wal_removal_end(res);
}
// If timeline is deleted while evicted decrement the gauge.
if mgr.tli.is_cancelled() && mgr.is_offloaded {
NUM_EVICTED_TIMELINES.dec();
}
mgr.set_status(Status::Finished);
}

View File

@@ -165,14 +165,12 @@ impl GlobalTimelines {
match Timeline::load_timeline(&conf, ttid) {
Ok(timeline) => {
let tli = Arc::new(timeline);
let mut shared_state = tli.write_shared_state().await;
TIMELINES_STATE
.lock()
.unwrap()
.timelines
.insert(ttid, tli.clone());
tli.bootstrap(
&mut shared_state,
&conf,
broker_active_set.clone(),
partial_backup_rate_limiter.clone(),
@@ -215,7 +213,6 @@ impl GlobalTimelines {
match Timeline::load_timeline(&conf, ttid) {
Ok(timeline) => {
let tli = Arc::new(timeline);
let mut shared_state = tli.write_shared_state().await;
// TODO: prevent concurrent timeline creation/loading
{
@@ -230,13 +227,8 @@ impl GlobalTimelines {
state.timelines.insert(ttid, tli.clone());
}
tli.bootstrap(
&mut shared_state,
&conf,
broker_active_set,
partial_backup_rate_limiter,
);
drop(shared_state);
tli.bootstrap(&conf, broker_active_set, partial_backup_rate_limiter);
Ok(tli)
}
// If we can't load a timeline, it's bad. Caller will figure it out.

View File

@@ -17,9 +17,7 @@ use std::time::Duration;
use postgres_ffi::v14::xlog_utils::XLogSegNoOffsetToRecPtr;
use postgres_ffi::XLogFileName;
use postgres_ffi::{XLogSegNo, PG_TLI};
use remote_storage::{
DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata,
};
use remote_storage::{GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata};
use tokio::fs::File;
use tokio::select;
@@ -505,12 +503,8 @@ pub async fn read_object(
let cancel = CancellationToken::new();
let opts = DownloadOpts {
byte_start: std::ops::Bound::Included(offset),
..Default::default()
};
let download = storage
.download(file_path, &opts, &cancel)
.download_storage_object(Some((offset, None)), file_path, &cancel)
.await
.with_context(|| {
format!("Failed to open WAL segment download stream for remote path {file_path:?}")

View File

@@ -1,10 +1,9 @@
#! /usr/bin/env python3
from __future__ import annotations
import argparse
import json
import logging
from typing import Dict
import psycopg2
import psycopg2.extras
@@ -111,7 +110,7 @@ def main(args: argparse.Namespace):
output = args.output
percentile = args.percentile
res: dict[str, float] = {}
res: Dict[str, float] = {}
try:
logging.info("connecting to the database...")

View File

@@ -4,9 +4,6 @@
#
# This can be useful in disaster recovery.
#
from __future__ import annotations
import argparse
import psycopg2

View File

@@ -1,21 +1,16 @@
#! /usr/bin/env python3
from __future__ import annotations
import argparse
import json
import logging
import os
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import Any, DefaultDict, Dict, Optional
import psycopg2
import psycopg2.extras
import toml
if TYPE_CHECKING:
from typing import Any, Optional
FLAKY_TESTS_QUERY = """
SELECT
DISTINCT parent_suite, suite, name
@@ -38,7 +33,7 @@ def main(args: argparse.Namespace):
build_type = args.build_type
pg_version = args.pg_version
res: defaultdict[str, defaultdict[str, dict[str, bool]]]
res: DefaultDict[str, DefaultDict[str, Dict[str, bool]]]
res = defaultdict(lambda: defaultdict(dict))
try:
@@ -65,7 +60,7 @@ def main(args: argparse.Namespace):
pageserver_virtual_file_io_engine_parameter = ""
# re-use existing records of flaky tests from before parametrization by compaction_algorithm
def get_pageserver_default_tenant_config_compaction_algorithm() -> Optional[dict[str, Any]]:
def get_pageserver_default_tenant_config_compaction_algorithm() -> Optional[Dict[str, Any]]:
"""Duplicated from parametrize.py"""
toml_table = os.getenv("PAGESERVER_DEFAULT_TENANT_CONFIG_COMPACTION_ALGORITHM")
if toml_table is None:

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import argparse
import asyncio
import json
@@ -7,15 +5,11 @@ import logging
import signal
import sys
from collections import defaultdict
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Any, Awaitable, Dict, List, Tuple
import aiohttp
if TYPE_CHECKING:
from typing import Any
class ClientException(Exception):
pass
@@ -95,7 +89,7 @@ class Client:
class Completed:
"""The status dict returned by the API"""
status: dict[str, Any]
status: Dict[str, Any]
sigint_received = asyncio.Event()
@@ -185,7 +179,7 @@ async def main_impl(args, report_out, client: Client):
"""
Returns OS exit status.
"""
tenant_and_timline_ids: list[tuple[str, str]] = []
tenant_and_timline_ids: List[Tuple[str, str]] = []
# fill tenant_and_timline_ids based on spec
for spec in args.what:
comps = spec.split(":")
@@ -221,14 +215,14 @@ async def main_impl(args, report_out, client: Client):
tenant_and_timline_ids = tmp
logging.info("create tasks and process them at specified concurrency")
task_q: asyncio.Queue[tuple[str, Awaitable[Any]]] = asyncio.Queue()
task_q: asyncio.Queue[Tuple[str, Awaitable[Any]]] = asyncio.Queue()
tasks = {
f"{tid}:{tlid}": do_timeline(client, tid, tlid) for tid, tlid in tenant_and_timline_ids
}
for task in tasks.items():
task_q.put_nowait(task)
result_q: asyncio.Queue[tuple[str, Any]] = asyncio.Queue()
result_q: asyncio.Queue[Tuple[str, Any]] = asyncio.Queue()
taskq_handlers = []
for _ in range(0, args.concurrent_tasks):
taskq_handlers.append(taskq_handler(task_q, result_q))

View File

@@ -1,7 +1,4 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import logging

View File

@@ -1,7 +1,5 @@
#! /usr/bin/env python3
from __future__ import annotations
import argparse
import dataclasses
import json
@@ -13,6 +11,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Tuple
import backoff
import psycopg2
@@ -92,7 +91,7 @@ def create_table(cur):
cur.execute(CREATE_TABLE)
def parse_test_name(test_name: str) -> tuple[str, int, str]:
def parse_test_name(test_name: str) -> Tuple[str, int, str]:
build_type, pg_version = None, None
if match := TEST_NAME_RE.search(test_name):
found = match.groupdict()

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import argparse
import logging
import os

View File

@@ -22,7 +22,7 @@ use utils::sync::gate::GateGuard;
use crate::compute_hook::{ComputeHook, NotifyError};
use crate::node::Node;
use crate::tenant_shard::{IntentState, ObservedState, ObservedStateDelta, ObservedStateLocation};
use crate::tenant_shard::{IntentState, ObservedState, ObservedStateLocation};
const DEFAULT_HEATMAP_PERIOD: &str = "60s";
@@ -45,15 +45,8 @@ pub(super) struct Reconciler {
pub(crate) reconciler_config: ReconcilerConfig,
pub(crate) config: TenantConfig,
/// Observed state from the point of view of the reconciler.
/// This gets updated as the reconciliation makes progress.
pub(crate) observed: ObservedState,
/// Snapshot of the observed state at the point when the reconciler
/// was spawned.
pub(crate) original_observed: ObservedState,
pub(crate) service_config: service::Config,
/// A hook to notify the running postgres instances when we change the location
@@ -853,39 +846,6 @@ impl Reconciler {
}
}
/// Diff the observed state captured when this reconciler was spawned against
/// the observed state it ended up with.
///
/// Produces an `Upsert` for every node whose location is new or whose
/// `conf` changed, followed by a `Delete` for every node that was present
/// originally but is no longer observed.
pub(crate) fn observed_deltas(&self) -> Vec<ObservedStateDelta> {
    let before = &self.original_observed.locations;
    let after = &self.observed.locations;

    // Locations that appeared or whose config differs from the snapshot.
    let upserts = after.iter().filter_map(|(node_id, location)| {
        let changed = match before.get(node_id) {
            Some(prev) => location.conf != prev.conf,
            None => true,
        };
        changed.then(|| ObservedStateDelta::Upsert(Box::new((*node_id, location.clone()))))
    });

    // Locations that existed in the snapshot but have since disappeared.
    let deletes = before
        .keys()
        .filter(|node_id| !after.contains_key(*node_id))
        .map(|node_id| ObservedStateDelta::Delete(*node_id));

    upserts.chain(deletes).collect()
}
/// Keep trying to notify the compute indefinitely, only dropping out if:
/// - the node `origin` becomes unavailable -> Ok(())
/// - the node `origin` no longer has our tenant shard attached -> Ok(())

View File

@@ -28,8 +28,8 @@ use crate::{
reconciler::{ReconcileError, ReconcileUnits, ReconcilerConfig, ReconcilerConfigBuilder},
scheduler::{MaySchedule, ScheduleContext, ScheduleError, ScheduleMode},
tenant_shard::{
MigrateAttachment, ObservedStateDelta, ReconcileNeeded, ReconcilerStatus,
ScheduleOptimization, ScheduleOptimizationAction,
MigrateAttachment, ReconcileNeeded, ReconcilerStatus, ScheduleOptimization,
ScheduleOptimizationAction,
},
};
use anyhow::Context;
@@ -966,8 +966,6 @@ impl Service {
let res = self.heartbeater.heartbeat(nodes).await;
if let Ok(deltas) = res {
let mut to_handle = Vec::default();
for (node_id, state) in deltas.0 {
let new_availability = match state {
PageserverState::Available { utilization, .. } => {
@@ -999,27 +997,14 @@ impl Service {
}
};
let node_lock = trace_exclusive_lock(
&self.node_op_locks,
node_id,
NodeOperations::Configure,
)
.await;
// This is the code path for genuine availability transitions (i.e node
// goes unavailable and/or comes back online).
let res = self
.node_state_configure(node_id, Some(new_availability), None, &node_lock)
.node_configure(node_id, Some(new_availability), None)
.await;
match res {
Ok(transition) => {
// Keep hold of the lock until the availability transitions
// have been handled in
// [`Service::handle_node_availability_transitions`] in order to avoid
// racing with [`Service::external_node_configure`].
to_handle.push((node_id, node_lock, transition));
}
Ok(()) => {}
Err(ApiError::NotFound(_)) => {
// This should be rare, but legitimate since the heartbeats are done
// on a snapshot of the nodes.
@@ -1029,37 +1014,13 @@ impl Service {
// Transition to active involves reconciling: if a node responds to a heartbeat then
// becomes unavailable again, we may get an error here.
tracing::error!(
"Failed to update node state {} after heartbeat round: {}",
"Failed to update node {} after heartbeat round: {}",
node_id,
err
);
}
}
}
// We collected all the transitions above and now we handle them.
let res = self.handle_node_availability_transitions(to_handle).await;
if let Err(errs) = res {
for (node_id, err) in errs {
match err {
ApiError::NotFound(_) => {
// This should be rare, but legitimate since the heartbeats are done
// on a snapshot of the nodes.
tracing::info!(
"Node {} was not found after heartbeat round",
node_id
);
}
err => {
tracing::error!(
"Failed to handle availability transition for {} after heartbeat round: {}",
node_id,
err
);
}
}
}
}
}
}
}
@@ -1072,7 +1033,7 @@ impl Service {
tenant_id=%result.tenant_shard_id.tenant_id, shard_id=%result.tenant_shard_id.shard_slug(),
sequence=%result.sequence
))]
fn process_result(&self, result: ReconcileResult) {
fn process_result(&self, mut result: ReconcileResult) {
let mut locked = self.inner.write().unwrap();
let (nodes, tenants, _scheduler) = locked.parts_mut();
let Some(tenant) = tenants.get_mut(&result.tenant_shard_id) else {
@@ -1094,27 +1055,22 @@ impl Service {
// In case a node was deleted while this reconcile is in flight, filter it out of the update we will
// make to the tenant
let deltas = result.observed_deltas.into_iter().flat_map(|delta| {
// In case a node was deleted while this reconcile is in flight, filter it out of the update we will
// make to the tenant
let node = nodes.get(delta.node_id())?;
if node.is_available() {
return Some(delta);
}
// In case a node became unavailable concurrently with the reconcile, observed
// locations on it are now uncertain. By convention, set them to None in order
// for them to get refreshed when the node comes back online.
Some(ObservedStateDelta::Upsert(Box::new((
node.get_id(),
ObservedStateLocation { conf: None },
))))
});
result
.observed
.locations
.retain(|node_id, _loc| nodes.contains_key(node_id));
match result.result {
Ok(()) => {
tenant.apply_observed_deltas(deltas);
for (node_id, loc) in &result.observed.locations {
if let Some(conf) = &loc.conf {
tracing::info!("Updating observed location {}: {:?}", node_id, conf);
} else {
tracing::info!("Setting observed location {} to None", node_id,)
}
}
tenant.observed = result.observed;
tenant.waiter.advance(result.sequence);
}
Err(e) => {
@@ -1136,10 +1092,9 @@ impl Service {
// so that waiters will see the correct error after waiting.
tenant.set_last_error(result.sequence, e);
// Skip deletions on reconcile failures
let upsert_deltas =
deltas.filter(|delta| matches!(delta, ObservedStateDelta::Upsert(_)));
tenant.apply_observed_deltas(upsert_deltas);
for (node_id, o) in result.observed.locations {
tenant.observed.locations.insert(node_id, o);
}
}
}
@@ -5344,17 +5299,15 @@ impl Service {
Ok(())
}
/// Configure in-memory and persistent state of a node as requested
///
/// Note that this function does not trigger any immediate side effects in response
/// to the changes. That part is handled by [`Self::handle_node_availability_transition`].
async fn node_state_configure(
pub(crate) async fn node_configure(
&self,
node_id: NodeId,
availability: Option<NodeAvailability>,
scheduling: Option<NodeSchedulingPolicy>,
node_lock: &TracingExclusiveGuard<NodeOperations>,
) -> Result<AvailabilityTransition, ApiError> {
) -> Result<(), ApiError> {
let _node_lock =
trace_exclusive_lock(&self.node_op_locks, node_id, NodeOperations::Configure).await;
if let Some(scheduling) = scheduling {
// Scheduling is a persistent part of Node: we must write updates to the database before
// applying them in memory
@@ -5383,7 +5336,7 @@ impl Service {
};
if matches!(availability_transition, AvailabilityTransition::ToActive) {
self.node_activate_reconcile(activate_node, node_lock)
self.node_activate_reconcile(activate_node, &_node_lock)
.await?;
}
availability_transition
@@ -5393,7 +5346,7 @@ impl Service {
// Apply changes from the request to our in-memory state for the Node
let mut locked = self.inner.write().unwrap();
let (nodes, _tenants, scheduler) = locked.parts_mut();
let (nodes, tenants, scheduler) = locked.parts_mut();
let mut new_nodes = (**nodes).clone();
@@ -5403,8 +5356,8 @@ impl Service {
));
};
if let Some(availability) = availability {
node.set_availability(availability);
if let Some(availability) = availability.as_ref() {
node.set_availability(availability.clone());
}
if let Some(scheduling) = scheduling {
@@ -5415,30 +5368,11 @@ impl Service {
scheduler.node_upsert(node);
let new_nodes = Arc::new(new_nodes);
locked.nodes = new_nodes;
Ok(availability_transition)
}
/// Handle availability transition of one node
///
/// Note that you should first call [`Self::node_state_configure`] to update
/// the in-memory state referencing that node. If you need to handle more than one transition
/// consider using [`Self::handle_node_availability_transitions`].
async fn handle_node_availability_transition(
&self,
node_id: NodeId,
transition: AvailabilityTransition,
_node_lock: &TracingExclusiveGuard<NodeOperations>,
) -> Result<(), ApiError> {
// Modify scheduling state for any Tenants that are affected by a change in the node's availability state.
match transition {
match availability_transition {
AvailabilityTransition::ToOffline => {
tracing::info!("Node {} transition to offline", node_id);
let mut locked = self.inner.write().unwrap();
let (nodes, tenants, scheduler) = locked.parts_mut();
let mut tenants_affected: usize = 0;
for (tenant_shard_id, tenant_shard) in tenants {
@@ -5448,14 +5382,14 @@ impl Service {
observed_loc.conf = None;
}
if nodes.len() == 1 {
if new_nodes.len() == 1 {
// Special case for single-node cluster: there is no point trying to reschedule
// any tenant shards: avoid doing so, in order to avoid spewing warnings about
// failures to schedule them.
continue;
}
if !nodes
if !new_nodes
.values()
.any(|n| matches!(n.may_schedule(), MaySchedule::Yes(_)))
{
@@ -5481,7 +5415,10 @@ impl Service {
tracing::warn!(%tenant_shard_id, "Scheduling error when marking pageserver {} offline: {e}", node_id);
}
Ok(()) => {
if self.maybe_reconcile_shard(tenant_shard, nodes).is_some() {
if self
.maybe_reconcile_shard(tenant_shard, &new_nodes)
.is_some()
{
tenants_affected += 1;
};
}
@@ -5496,13 +5433,9 @@ impl Service {
}
AvailabilityTransition::ToActive => {
tracing::info!("Node {} transition to active", node_id);
let mut locked = self.inner.write().unwrap();
let (nodes, tenants, _scheduler) = locked.parts_mut();
// When a node comes back online, we must reconcile any tenant that has a None observed
// location on the node.
for tenant_shard in tenants.values_mut() {
for tenant_shard in locked.tenants.values_mut() {
// If a reconciliation is already in progress, rely on the previous scheduling
// decision and skip triggering a new reconciliation.
if tenant_shard.reconciler.is_some() {
@@ -5511,7 +5444,7 @@ impl Service {
if let Some(observed_loc) = tenant_shard.observed.locations.get_mut(&node_id) {
if observed_loc.conf.is_none() {
self.maybe_reconcile_shard(tenant_shard, nodes);
self.maybe_reconcile_shard(tenant_shard, &new_nodes);
}
}
}
@@ -5532,54 +5465,11 @@ impl Service {
}
}
locked.nodes = new_nodes;
Ok(())
}
/// Handle availability transitions for a batch of nodes.
///
/// Call [`Self::node_state_configure`] for every node in the batch first, so
/// that the handling here works from fresh in-memory state.
///
/// Transitions are processed one at a time, in order. A failure for one node
/// does not stop the rest; all failures are returned together, keyed by node.
async fn handle_node_availability_transitions(
    &self,
    transitions: Vec<(
        NodeId,
        TracingExclusiveGuard<NodeOperations>,
        AvailabilityTransition,
    )>,
) -> Result<(), Vec<(NodeId, ApiError)>> {
    let mut failures = Vec::new();

    for (node_id, node_lock, transition) in transitions {
        // The node's lock guard stays alive for the duration of its handling.
        if let Err(err) = self
            .handle_node_availability_transition(node_id, transition, &node_lock)
            .await
        {
            failures.push((node_id, err));
        }
    }

    match failures.is_empty() {
        true => Ok(()),
        false => Err(failures),
    }
}
/// Apply an availability and/or scheduling change to a node and then handle
/// the resulting availability transition, all under the node's exclusive
/// `Configure` operation lock.
pub(crate) async fn node_configure(
    &self,
    node_id: NodeId,
    availability: Option<NodeAvailability>,
    scheduling: Option<NodeSchedulingPolicy>,
) -> Result<(), ApiError> {
    // Held across both steps so the state update and its handling are not
    // interleaved with another configure operation on the same node.
    let guard =
        trace_exclusive_lock(&self.node_op_locks, node_id, NodeOperations::Configure).await;

    let state_update = self.node_state_configure(node_id, availability, scheduling, &guard);
    let transition = state_update.await?;

    self.handle_node_availability_transition(node_id, transition, &guard)
        .await
}
/// Wrapper around [`Self::node_configure`] which only allows changes while there is no ongoing
/// operation for HTTP api.
pub(crate) async fn external_node_configure(

View File

@@ -425,22 +425,6 @@ pub(crate) enum ReconcileNeeded {
Yes,
}
/// Pending modification to the observed state of a tenant shard.
/// Produced by [`Reconciler::observed_deltas`] and applied in [`crate::service::Service::process_result`].
pub(crate) enum ObservedStateDelta {
// Insert or replace the observed location for the given node.
// The payload is boxed.
Upsert(Box<(NodeId, ObservedStateLocation)>),
// Remove the observed location for the given node.
Delete(NodeId),
}
impl ObservedStateDelta {
    /// The node this delta refers to, whichever variant it is.
    pub(crate) fn node_id(&self) -> &NodeId {
        match self {
            Self::Delete(node_id) => node_id,
            Self::Upsert(entry) => {
                let (node_id, _location) = &**entry;
                node_id
            }
        }
    }
}
/// When a reconcile task completes, it sends this result object
/// to be applied to the primary TenantShard.
pub(crate) struct ReconcileResult {
@@ -453,7 +437,7 @@ pub(crate) struct ReconcileResult {
pub(crate) tenant_shard_id: TenantShardId,
pub(crate) generation: Option<Generation>,
pub(crate) observed_deltas: Vec<ObservedStateDelta>,
pub(crate) observed: ObservedState,
/// Set [`TenantShard::pending_compute_notification`] from this flag
pub(crate) pending_compute_notification: bool,
@@ -1139,7 +1123,7 @@ impl TenantShard {
result,
tenant_shard_id: reconciler.tenant_shard_id,
generation: reconciler.generation,
observed_deltas: reconciler.observed_deltas(),
observed: reconciler.observed,
pending_compute_notification: reconciler.compute_notify_failure,
}
}
@@ -1193,7 +1177,6 @@ impl TenantShard {
reconciler_config,
config: self.config.clone(),
observed: self.observed.clone(),
original_observed: self.observed.clone(),
compute_hook: compute_hook.clone(),
service_config: service_config.clone(),
_gate_guard: gate_guard,
@@ -1454,62 +1437,6 @@ impl TenantShard {
.map(|(node_id, gen)| (node_id, Generation::new(gen)))
.collect()
}
/// Fold a stream of incremental deltas into this tenant's observed state.
///
/// Deltas are generated by reconcilers via [`Reconciler::observed_deltas`].
/// They are then filtered in [`crate::service::Service::process_result`].
pub(crate) fn apply_observed_deltas(
    &mut self,
    deltas: impl Iterator<Item = ObservedStateDelta>,
) {
    for delta in deltas {
        // Deletions are handled immediately; only upserts need the
        // generation check below.
        let (node_id, loc) = match delta {
            ObservedStateDelta::Delete(node_id) => {
                tracing::info!("Deleting observed location {}", node_id);
                self.observed.locations.remove(&node_id);
                continue;
            }
            ObservedStateDelta::Upsert(ups) => *ups,
        };

        let crnt_gen = self
            .observed
            .locations
            .get(&node_id)
            .and_then(|existing| existing.conf.as_ref())
            .and_then(|conf| conf.generation);
        let new_gen = loc.conf.as_ref().and_then(|conf| conf.generation);

        // If the generation of the observed location in the delta is lagging
        // behind the current one, then we have a race condition and cannot
        // be certain about the true observed state. Set the observed state
        // to None in order to reflect this.
        if let (Some(crnt), Some(new)) = (crnt_gen, new_gen) {
            if crnt > new {
                tracing::warn!(
                    "Skipping observed state update {}: {:?} and using None due to stale generation ({} > {})",
                    node_id, loc, crnt, new
                );
                self.observed
                    .locations
                    .insert(node_id, ObservedStateLocation { conf: None });
                continue;
            }
        }

        match &loc.conf {
            Some(conf) => {
                tracing::info!("Updating observed location {}: {:?}", node_id, conf)
            }
            None => tracing::info!("Setting observed location {} to None", node_id,),
        }
        self.observed.locations.insert(node_id, loc);
    }
}
}
#[cfg(test)]

View File

@@ -2,8 +2,6 @@
Run the regression tests on the cloud instance of Neon
"""
from __future__ import annotations
from pathlib import Path
from typing import Any

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
pytest_plugins = (
"fixtures.pg_version",
"fixtures.parametrize",

View File

@@ -1 +0,0 @@
from __future__ import annotations

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import calendar
import dataclasses
import enum
@@ -10,7 +8,9 @@ import timeit
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING
# Type-related stuff
from typing import Callable, ClassVar, Dict, Iterator, Optional
import allure
import pytest
@@ -23,11 +23,6 @@ from fixtures.common_types import TenantId, TimelineId
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonPageserver
if TYPE_CHECKING:
from collections.abc import Iterator, Mapping
from typing import Callable, Optional
"""
This file contains fixtures for micro-benchmarks.
@@ -141,30 +136,20 @@ class PgBenchRunResult:
)
# Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171
#
# This used to be a class variable on PgBenchInitResult. However later versions
# of Python complain:
#
# ValueError: mutable default <class 'dict'> for field EXTRACTORS is not allowed: use default_factory
#
# When you do what the error tells you to do, it seems to fail our Python 3.9
# test environment. So let's just move it to a private module constant, and move
# on.
_PGBENCH_INIT_EXTRACTORS: Mapping[str, re.Pattern[str]] = {
"drop_tables": re.compile(r"drop tables (\d+\.\d+) s"),
"create_tables": re.compile(r"create tables (\d+\.\d+) s"),
"client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"),
"server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"),
"vacuum": re.compile(r"vacuum (\d+\.\d+) s"),
"primary_keys": re.compile(r"primary keys (\d+\.\d+) s"),
"foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"),
"total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench
}
@dataclasses.dataclass
class PgBenchInitResult:
# Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171
EXTRACTORS: ClassVar[Dict[str, re.Pattern]] = { # type: ignore[type-arg]
"drop_tables": re.compile(r"drop tables (\d+\.\d+) s"),
"create_tables": re.compile(r"create tables (\d+\.\d+) s"),
"client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"),
"server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"),
"vacuum": re.compile(r"vacuum (\d+\.\d+) s"),
"primary_keys": re.compile(r"primary keys (\d+\.\d+) s"),
"foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"),
"total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench
}
total: Optional[float]
drop_tables: Optional[float]
create_tables: Optional[float]
@@ -190,10 +175,10 @@ class PgBenchInitResult:
last_line = stderr.splitlines()[-1]
timings: dict[str, Optional[float]] = {}
timings: Dict[str, Optional[float]] = {}
last_line_items = re.split(r"\(|\)|,", last_line)
for item in last_line_items:
for key, regex in _PGBENCH_INIT_EXTRACTORS.items():
for key, regex in cls.EXTRACTORS.items():
if (m := regex.match(item.strip())) is not None:
if key in timings:
raise RuntimeError(
@@ -400,7 +385,7 @@ class NeonBenchmarker:
self,
pageserver: NeonPageserver,
metric_name: str,
label_filters: Optional[dict[str, str]] = None,
label_filters: Optional[Dict[str, str]] = None,
) -> int:
"""Fetch the value of given int counter from pageserver metrics."""
all_metrics = pageserver.http_client().get_metrics()

View File

@@ -1,18 +1,10 @@
from __future__ import annotations
import random
from dataclasses import dataclass
from enum import Enum
from functools import total_ordering
from typing import TYPE_CHECKING, TypeVar
from typing_extensions import override
if TYPE_CHECKING:
from typing import Any, Union
T = TypeVar("T", bound="Id")
from typing import Any, Dict, Type, TypeVar, Union
T = TypeVar("T", bound="Id")
DEFAULT_WAL_SEG_SIZE = 16 * 1024 * 1024
@@ -33,41 +25,38 @@ class Lsn:
self.lsn_int = (int(left, 16) << 32) + int(right, 16)
assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF
@override
def __str__(self) -> str:
"""Convert lsn from int to standard hex notation."""
return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}"
@override
def __repr__(self) -> str:
return f'Lsn("{str(self)}")'
def __int__(self) -> int:
return self.lsn_int
def __lt__(self, other: object) -> bool:
def __lt__(self, other: Any) -> bool:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int < other.lsn_int
def __gt__(self, other: object) -> bool:
def __gt__(self, other: Any) -> bool:
    """Return True when this LSN is strictly greater than `other`.

    For operands that are not `Lsn`, return the `NotImplemented` sentinel
    (as `__lt__`, `__eq__` and `__sub__` do) so Python can try the reflected
    operation and ultimately raise `TypeError`. Raising `NotImplementedError`
    here would bypass that protocol.
    """
    if not isinstance(other, Lsn):
        return NotImplemented
    return self.lsn_int > other.lsn_int
@override
def __eq__(self, other: object) -> bool:
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int == other.lsn_int
# Returns the difference between two Lsns, in bytes
def __sub__(self, other: object) -> int:
def __sub__(self, other: Any) -> int:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int - other.lsn_int
def __add__(self, other: Union[int, Lsn]) -> Lsn:
def __add__(self, other: Union[int, "Lsn"]) -> "Lsn":
if isinstance(other, int):
return Lsn(self.lsn_int + other)
elif isinstance(other, Lsn):
@@ -75,14 +64,13 @@ class Lsn:
else:
raise NotImplementedError
@override
def __hash__(self) -> int:
return hash(self.lsn_int)
def as_int(self) -> int:
return self.lsn_int
def segment_lsn(self, seg_sz: int = DEFAULT_WAL_SEG_SIZE) -> Lsn:
def segment_lsn(self, seg_sz: int = DEFAULT_WAL_SEG_SIZE) -> "Lsn":
return Lsn(self.lsn_int - (self.lsn_int % seg_sz))
def segno(self, seg_sz: int = DEFAULT_WAL_SEG_SIZE) -> int:
@@ -122,57 +110,48 @@ class Id:
self.id = bytearray.fromhex(x)
assert len(self.id) == 16
@override
def __str__(self) -> str:
return self.id.hex()
def __lt__(self, other: object) -> bool:
def __lt__(self, other) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self.id < other.id
@override
def __eq__(self, other: object) -> bool:
def __eq__(self, other) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self.id == other.id
@override
def __hash__(self) -> int:
return hash(str(self.id))
@classmethod
def generate(cls: type[T]) -> T:
def generate(cls: Type[T]) -> T:
"""Generate a random ID"""
return cls(random.randbytes(16).hex())
class TenantId(Id):
@override
def __repr__(self) -> str:
    """Return a constructor-like representation, e.g. `TenantId("<32 hex chars>")`."""
    # The leading backtick that used to precede `TenantId` was a typo; the
    # sibling `TimelineId.__repr__` has never emitted one.
    return f'TenantId("{self.id.hex()}")'
@override
def __str__(self) -> str:
return self.id.hex()
class NodeId(Id):
@override
def __repr__(self) -> str:
    """Return a constructor-like representation, e.g. `NodeId("<32 hex chars>")`."""
    # The leading backtick that used to precede `NodeId` was a typo; the
    # sibling `TimelineId.__repr__` has never emitted one.
    return f'NodeId("{self.id.hex()}")'
@override
def __str__(self) -> str:
return self.id.hex()
class TimelineId(Id):
@override
def __repr__(self) -> str:
return f'TimelineId("{self.id.hex()}")'
@override
def __str__(self) -> str:
return self.id.hex()
@@ -183,7 +162,7 @@ class TenantTimelineId:
timeline_id: TimelineId
@classmethod
def from_json(cls, d: dict[str, Any]) -> TenantTimelineId:
def from_json(cls, d: Dict[str, Any]) -> "TenantTimelineId":
return TenantTimelineId(
tenant_id=TenantId(d["tenant_id"]),
timeline_id=TimelineId(d["timeline_id"]),
@@ -202,7 +181,7 @@ class TenantShardId:
assert self.shard_number < self.shard_count or self.shard_count == 0
@classmethod
def parse(cls: type[TTenantShardId], input: str) -> TTenantShardId:
def parse(cls: Type[TTenantShardId], input) -> TTenantShardId:
if len(input) == 32:
return cls(
tenant_id=TenantId(input),
@@ -218,7 +197,6 @@ class TenantShardId:
else:
raise ValueError(f"Invalid TenantShardId '{input}'")
@override
def __str__(self):
if self.shard_count > 0:
return f"{self.tenant_id}-{self.shard_number:02x}{self.shard_count:02x}"
@@ -226,25 +204,22 @@ class TenantShardId:
# Unsharded case: equivalent of Rust TenantShardId::unsharded(tenant_id)
return str(self.tenant_id)
@override
def __repr__(self):
return self.__str__()
def _tuple(self) -> tuple[TenantId, int, int]:
return (self.tenant_id, self.shard_number, self.shard_count)
def __lt__(self, other: object) -> bool:
def __lt__(self, other) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self._tuple() < other._tuple()
@override
def __eq__(self, other: object) -> bool:
def __eq__(self, other) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self._tuple() == other._tuple()
@override
def __hash__(self) -> int:
return hash(self._tuple())

View File

@@ -1,18 +1,14 @@
from __future__ import annotations
import os
import time
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import _GeneratorContextManager, contextmanager
# Type-related stuff
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Dict, Iterator, List
import pytest
from _pytest.fixtures import FixtureRequest
from typing_extensions import override
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
from fixtures.log_helper import log
@@ -26,9 +22,6 @@ from fixtures.neon_fixtures import (
)
from fixtures.pg_stats import PgStatTable
if TYPE_CHECKING:
from collections.abc import Iterator
class PgCompare(ABC):
"""Common interface of all postgres implementations, useful for benchmarks.
@@ -70,16 +63,16 @@ class PgCompare(ABC):
@contextmanager
@abstractmethod
def record_pageserver_writes(self, out_name: str):
def record_pageserver_writes(self, out_name):
pass
@contextmanager
@abstractmethod
def record_duration(self, out_name: str):
def record_duration(self, out_name):
pass
@contextmanager
def record_pg_stats(self, pg_stats: list[PgStatTable]) -> Iterator[None]:
def record_pg_stats(self, pg_stats: List[PgStatTable]) -> Iterator[None]:
init_data = self._retrieve_pg_stats(pg_stats)
yield
@@ -89,8 +82,8 @@ class PgCompare(ABC):
for k in set(init_data) & set(data):
self.zenbenchmark.record(k, data[k] - init_data[k], "", MetricReport.HIGHER_IS_BETTER)
def _retrieve_pg_stats(self, pg_stats: list[PgStatTable]) -> dict[str, int]:
results: dict[str, int] = {}
def _retrieve_pg_stats(self, pg_stats: List[PgStatTable]) -> Dict[str, int]:
results: Dict[str, int] = {}
with self.pg.connect().cursor() as cur:
for pg_stat in pg_stats:
@@ -127,34 +120,28 @@ class NeonCompare(PgCompare):
self._pg = self.env.endpoints.create_start("main", "main", self.tenant)
@property
@override
def pg(self) -> PgProtocol:
return self._pg
@property
@override
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
@override
def pg_bin(self) -> PgBin:
return self._pg_bin
@override
def flush(self, compact: bool = True, gc: bool = True):
wait_for_last_flush_lsn(self.env, self._pg, self.tenant, self.timeline)
self.pageserver_http_client.timeline_checkpoint(self.tenant, self.timeline, compact=compact)
if gc:
self.pageserver_http_client.timeline_gc(self.tenant, self.timeline, 0)
@override
def compact(self):
self.pageserver_http_client.timeline_compact(
self.tenant, self.timeline, wait_until_uploaded=True
)
@override
def report_peak_memory_use(self):
self.zenbenchmark.record(
"peak_mem",
@@ -163,7 +150,6 @@ class NeonCompare(PgCompare):
report=MetricReport.LOWER_IS_BETTER,
)
@override
def report_size(self):
timeline_size = self.zenbenchmark.get_timeline_size(
self.env.repo_dir, self.tenant, self.timeline
@@ -197,11 +183,9 @@ class NeonCompare(PgCompare):
"num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER
)
@override
def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name)
@override
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)
@@ -225,33 +209,26 @@ class VanillaCompare(PgCompare):
self.cur = self.conn.cursor()
@property
@override
def pg(self) -> VanillaPostgres:
return self._pg
@property
@override
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
@override
def pg_bin(self) -> PgBin:
return self._pg.pg_bin
@override
def flush(self, compact: bool = False, gc: bool = False):
self.cur.execute("checkpoint")
@override
def compact(self):
pass
@override
def report_peak_memory_use(self):
pass # TODO find something
@override
def report_size(self):
data_size = self.pg.get_subdir_size(Path("base"))
self.zenbenchmark.record(
@@ -266,7 +243,6 @@ class VanillaCompare(PgCompare):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
yield # Do nothing
@override
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)
@@ -283,35 +259,28 @@ class RemoteCompare(PgCompare):
self.cur = self.conn.cursor()
@property
@override
def pg(self) -> PgProtocol:
return self._pg
@property
@override
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
@override
def pg_bin(self) -> PgBin:
return self._pg.pg_bin
@override
def flush(self, compact: bool = False, gc: bool = False):
def flush(self):
# TODO: flush the remote pageserver
pass
@override
def compact(self):
pass
@override
def report_peak_memory_use(self):
# TODO: get memory usage from remote pageserver
pass
@override
def report_size(self):
# TODO: get storage size from remote pageserver
pass
@@ -320,7 +289,6 @@ class RemoteCompare(PgCompare):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
yield # Do nothing
@override
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)

View File

@@ -1,31 +1,25 @@
from __future__ import annotations
import concurrent.futures
from typing import TYPE_CHECKING
from typing import Any
import pytest
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
from fixtures.common_types import TenantId
from fixtures.log_helper import log
if TYPE_CHECKING:
from typing import Any, Callable, Optional
class ComputeReconfigure:
def __init__(self, server: HTTPServer):
def __init__(self, server):
self.server = server
self.control_plane_compute_hook_api = f"http://{server.host}:{server.port}/notify-attach"
self.workloads: dict[TenantId, Any] = {}
self.on_notify: Optional[Callable[[Any], None]] = None
self.workloads = {}
self.on_notify = None
def register_workload(self, workload: Any):
def register_workload(self, workload):
self.workloads[workload.tenant_id] = workload
def register_on_notify(self, fn: Optional[Callable[[Any], None]]):
def register_on_notify(self, fn):
"""
Add some extra work during a notification, like sleeping to slow things down, or
logging what was notified.
@@ -34,7 +28,7 @@ class ComputeReconfigure:
@pytest.fixture(scope="function")
def compute_reconfigure_listener(make_httpserver: HTTPServer):
def compute_reconfigure_listener(make_httpserver):
"""
This fixture exposes an HTTP listener for the storage controller to submit
compute notifications to us, instead of updating neon_local endpoints itself.
@@ -52,7 +46,7 @@ def compute_reconfigure_listener(make_httpserver: HTTPServer):
# accept a healthy rate of calls into notify-attach.
reconfigure_threads = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def handler(request: Request) -> Response:
def handler(request: Request):
assert request.json is not None
body: dict[str, Any] = request.json
log.info(f"notify-attach request: {body}")

View File

@@ -1 +0,0 @@
from __future__ import annotations

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import requests
from requests.adapters import HTTPAdapter
@@ -23,8 +21,3 @@ class EndpointHttpClient(requests.Session):
res = self.get(f"http://localhost:{self.port}/database_schema?database={database}")
res.raise_for_status()
return res.text
# Query the `/installed_extensions` endpoint on localhost at `self.port`.
# Raises `requests.HTTPError` (via `raise_for_status`) on a 4xx/5xx response;
# otherwise returns the response body decoded from JSON.
def installed_extensions(self):
res = self.get(f"http://localhost:{self.port}/installed_extensions")
res.raise_for_status()
return res.json()

View File

@@ -1,9 +1,6 @@
from __future__ import annotations
import json
from collections.abc import MutableMapping
from pathlib import Path
from typing import TYPE_CHECKING, cast
from typing import Any, List, MutableMapping, cast
import pytest
from _pytest.config import Config
@@ -13,11 +10,6 @@ from allure_pytest.utils import allure_name, allure_suite_labels
from fixtures.log_helper import log
if TYPE_CHECKING:
from collections.abc import MutableMapping
from typing import Any
"""
The plugin reruns flaky tests.
It uses `pytest.mark.flaky` provided by `pytest-rerunfailures` plugin and flaky tests detected by `scripts/flaky_tests.py`
@@ -35,7 +27,7 @@ def pytest_addoption(parser: Parser):
)
def pytest_collection_modifyitems(config: Config, items: list[pytest.Item]):
def pytest_collection_modifyitems(config: Config, items: List[pytest.Item]):
if not config.getoption("--flaky-tests-json"):
return
@@ -74,5 +66,5 @@ def pytest_collection_modifyitems(config: Config, items: list[pytest.Item]):
# - [2] https://github.com/pytest-dev/pytest-timeout/issues/142
timeout_marker = item.get_closest_marker("timeout")
if timeout_marker is not None:
kwargs = cast("MutableMapping[str, Any]", timeout_marker.kwargs)
kwargs = cast(MutableMapping[str, Any], timeout_marker.kwargs)
kwargs["func_only"] = True

View File

@@ -1,15 +1,8 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Tuple
import pytest
from pytest_httpserver import HTTPServer
if TYPE_CHECKING:
from collections.abc import Iterator
from fixtures.port_distributor import PortDistributor
# TODO: mypy fails with:
# Module "fixtures.neon_fixtures" does not explicitly export attribute "PortDistributor" [attr-defined]
# from fixtures.neon_fixtures import PortDistributor
@@ -24,7 +17,7 @@ def httpserver_ssl_context():
@pytest.fixture(scope="function")
def make_httpserver(httpserver_listen_address, httpserver_ssl_context) -> Iterator[HTTPServer]:
def make_httpserver(httpserver_listen_address, httpserver_ssl_context):
host, port = httpserver_listen_address
if not host:
host = HTTPServer.DEFAULT_LISTEN_HOST
@@ -40,13 +33,13 @@ def make_httpserver(httpserver_listen_address, httpserver_ssl_context) -> Iterat
@pytest.fixture(scope="function")
def httpserver(make_httpserver: HTTPServer) -> Iterator[HTTPServer]:
def httpserver(make_httpserver):
server = make_httpserver
yield server
server.clear()
@pytest.fixture(scope="function")
def httpserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]:
def httpserver_listen_address(port_distributor) -> Tuple[str, int]:
port = port_distributor.get_port()
return ("localhost", port)

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import logging
import logging.config
@@ -31,7 +29,7 @@ LOGGING = {
}
def getLogger(name: str = "root") -> logging.Logger:
def getLogger(name="root") -> logging.Logger:
"""Method to get logger for tests.
Should be used to get correctly initialized logger."""

View File

@@ -1,28 +1,23 @@
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import Dict, List, Optional, Tuple
from prometheus_client.parser import text_string_to_metric_families
from prometheus_client.samples import Sample
from fixtures.log_helper import log
if TYPE_CHECKING:
from typing import Optional
class Metrics:
metrics: dict[str, list[Sample]]
metrics: Dict[str, List[Sample]]
name: str
def __init__(self, name: str = ""):
self.metrics = defaultdict(list)
self.name = name
def query_all(self, name: str, filter: Optional[dict[str, str]] = None) -> list[Sample]:
def query_all(self, name: str, filter: Optional[Dict[str, str]] = None) -> List[Sample]:
filter = filter or {}
res: list[Sample] = []
res = []
for sample in self.metrics[name]:
try:
@@ -32,7 +27,7 @@ class Metrics:
pass
return res
def query_one(self, name: str, filter: Optional[dict[str, str]] = None) -> Sample:
def query_one(self, name: str, filter: Optional[Dict[str, str]] = None) -> Sample:
res = self.query_all(name, filter or {})
assert len(res) == 1, f"expected single sample for {name} {filter}, found {res}"
return res[0]
@@ -48,7 +43,7 @@ class MetricsGetter:
raise NotImplementedError()
def get_metric_value(
self, name: str, filter: Optional[dict[str, str]] = None
self, name: str, filter: Optional[Dict[str, str]] = None
) -> Optional[float]:
metrics = self.get_metrics()
results = metrics.query_all(name, filter=filter)
@@ -59,8 +54,8 @@ class MetricsGetter:
return results[0].value
def get_metrics_values(
self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok: bool = False
) -> dict[str, float]:
self, names: list[str], filter: Optional[Dict[str, str]] = None, absence_ok=False
) -> Dict[str, float]:
"""
When fetching multiple named metrics, it is more efficient to use this
than to call `get_metric_value` repeatedly.
@@ -102,7 +97,7 @@ def parse_metrics(text: str, name: str = "") -> Metrics:
return metrics
def histogram(prefix_without_trailing_underscore: str) -> list[str]:
def histogram(prefix_without_trailing_underscore: str) -> List[str]:
assert not prefix_without_trailing_underscore.endswith("_")
return [f"{prefix_without_trailing_underscore}_{x}" for x in ["bucket", "count", "sum"]]
@@ -112,7 +107,7 @@ def counter(name: str) -> str:
return f"{name}_total"
PAGESERVER_PER_TENANT_REMOTE_TIMELINE_CLIENT_METRICS: tuple[str, ...] = (
PAGESERVER_PER_TENANT_REMOTE_TIMELINE_CLIENT_METRICS: Tuple[str, ...] = (
"pageserver_remote_timeline_client_calls_started_total",
"pageserver_remote_timeline_client_calls_finished_total",
"pageserver_remote_physical_size",
@@ -120,7 +115,7 @@ PAGESERVER_PER_TENANT_REMOTE_TIMELINE_CLIENT_METRICS: tuple[str, ...] = (
"pageserver_remote_timeline_client_bytes_finished_total",
)
PAGESERVER_GLOBAL_METRICS: tuple[str, ...] = (
PAGESERVER_GLOBAL_METRICS: Tuple[str, ...] = (
"pageserver_storage_operations_seconds_global_count",
"pageserver_storage_operations_seconds_global_sum",
"pageserver_storage_operations_seconds_global_bucket",
@@ -152,7 +147,7 @@ PAGESERVER_GLOBAL_METRICS: tuple[str, ...] = (
counter("pageserver_tenant_throttling_count_global"),
)
PAGESERVER_PER_TENANT_METRICS: tuple[str, ...] = (
PAGESERVER_PER_TENANT_METRICS: Tuple[str, ...] = (
"pageserver_current_logical_size",
"pageserver_resident_physical_size",
"pageserver_io_operations_bytes_total",

View File

@@ -6,12 +6,12 @@ from typing import TYPE_CHECKING, cast
import requests
if TYPE_CHECKING:
from typing import Any, Literal, Optional
from typing import Any, Dict, Literal, Optional, Union
from fixtures.pg_version import PgVersion
def connection_parameters_to_env(params: dict[str, str]) -> dict[str, str]:
def connection_parameters_to_env(params: Dict[str, str]) -> Dict[str, str]:
return {
"PGHOST": params["host"],
"PGDATABASE": params["database"],
@@ -25,7 +25,9 @@ class NeonAPI:
self.__neon_api_key = neon_api_key
self.__neon_api_base_url = neon_api_base_url.strip("/")
def __request(self, method: str | bytes, endpoint: str, **kwargs: Any) -> requests.Response:
def __request(
self, method: Union[str, bytes], endpoint: str, **kwargs: Any
) -> requests.Response:
if "headers" not in kwargs:
kwargs["headers"] = {}
kwargs["headers"]["Authorization"] = f"Bearer {self.__neon_api_key}"
@@ -39,8 +41,8 @@ class NeonAPI:
branch_name: Optional[str] = None,
branch_role_name: Optional[str] = None,
branch_database_name: Optional[str] = None,
) -> dict[str, Any]:
data: dict[str, Any] = {
) -> Dict[str, Any]:
data: Dict[str, Any] = {
"project": {
"branch": {},
},
@@ -68,9 +70,9 @@ class NeonAPI:
assert resp.status_code == 201
return cast("dict[str, Any]", resp.json())
return cast("Dict[str, Any]", resp.json())
def get_project_details(self, project_id: str) -> dict[str, Any]:
def get_project_details(self, project_id: str) -> Dict[str, Any]:
resp = self.__request(
"GET",
f"/projects/{project_id}",
@@ -80,12 +82,12 @@ class NeonAPI:
},
)
assert resp.status_code == 200
return cast("dict[str, Any]", resp.json())
return cast("Dict[str, Any]", resp.json())
def delete_project(
self,
project_id: str,
) -> dict[str, Any]:
) -> Dict[str, Any]:
resp = self.__request(
"DELETE",
f"/projects/{project_id}",
@@ -97,13 +99,13 @@ class NeonAPI:
assert resp.status_code == 200
return cast("dict[str, Any]", resp.json())
return cast("Dict[str, Any]", resp.json())
def start_endpoint(
self,
project_id: str,
endpoint_id: str,
) -> dict[str, Any]:
) -> Dict[str, Any]:
resp = self.__request(
"POST",
f"/projects/{project_id}/endpoints/{endpoint_id}/start",
@@ -114,13 +116,13 @@ class NeonAPI:
assert resp.status_code == 200
return cast("dict[str, Any]", resp.json())
return cast("Dict[str, Any]", resp.json())
def suspend_endpoint(
self,
project_id: str,
endpoint_id: str,
) -> dict[str, Any]:
) -> Dict[str, Any]:
resp = self.__request(
"POST",
f"/projects/{project_id}/endpoints/{endpoint_id}/suspend",
@@ -131,13 +133,13 @@ class NeonAPI:
assert resp.status_code == 200
return cast("dict[str, Any]", resp.json())
return cast("Dict[str, Any]", resp.json())
def restart_endpoint(
self,
project_id: str,
endpoint_id: str,
) -> dict[str, Any]:
) -> Dict[str, Any]:
resp = self.__request(
"POST",
f"/projects/{project_id}/endpoints/{endpoint_id}/restart",
@@ -148,16 +150,16 @@ class NeonAPI:
assert resp.status_code == 200
return cast("dict[str, Any]", resp.json())
return cast("Dict[str, Any]", resp.json())
def create_endpoint(
self,
project_id: str,
branch_id: str,
endpoint_type: Literal["read_write", "read_only"],
settings: dict[str, Any],
) -> dict[str, Any]:
data: dict[str, Any] = {
settings: Dict[str, Any],
) -> Dict[str, Any]:
data: Dict[str, Any] = {
"endpoint": {
"branch_id": branch_id,
},
@@ -180,17 +182,17 @@ class NeonAPI:
assert resp.status_code == 201
return cast("dict[str, Any]", resp.json())
return cast("Dict[str, Any]", resp.json())
def get_connection_uri(
self,
project_id: str,
branch_id: str | None = None,
endpoint_id: str | None = None,
branch_id: Optional[str] = None,
endpoint_id: Optional[str] = None,
database_name: str = "neondb",
role_name: str = "neondb_owner",
pooled: bool = True,
) -> dict[str, Any]:
) -> Dict[str, Any]:
resp = self.__request(
"GET",
f"/projects/{project_id}/connection_uri",
@@ -208,9 +210,9 @@ class NeonAPI:
assert resp.status_code == 200
return cast("dict[str, Any]", resp.json())
return cast("Dict[str, Any]", resp.json())
def get_branches(self, project_id: str) -> dict[str, Any]:
def get_branches(self, project_id: str) -> Dict[str, Any]:
resp = self.__request(
"GET",
f"/projects/{project_id}/branches",
@@ -221,9 +223,9 @@ class NeonAPI:
assert resp.status_code == 200
return cast("dict[str, Any]", resp.json())
return cast("Dict[str, Any]", resp.json())
def get_endpoints(self, project_id: str) -> dict[str, Any]:
def get_endpoints(self, project_id: str) -> Dict[str, Any]:
resp = self.__request(
"GET",
f"/projects/{project_id}/endpoints",
@@ -234,9 +236,9 @@ class NeonAPI:
assert resp.status_code == 200
return cast("dict[str, Any]", resp.json())
return cast("Dict[str, Any]", resp.json())
def get_operations(self, project_id: str) -> dict[str, Any]:
def get_operations(self, project_id: str) -> Dict[str, Any]:
resp = self.__request(
"GET",
f"/projects/{project_id}/operations",
@@ -248,7 +250,7 @@ class NeonAPI:
assert resp.status_code == 200
return cast("dict[str, Any]", resp.json())
return cast("Dict[str, Any]", resp.json())
def wait_for_operation_to_finish(self, project_id: str):
has_running = True
@@ -262,7 +264,7 @@ class NeonAPI:
class NeonApiEndpoint:
def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: str | None):
def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: Optional[str]):
self.neon_api = neon_api
if project_id is None:
project = neon_api.create_project(pg_version)

View File

@@ -9,7 +9,15 @@ import tempfile
import textwrap
from itertools import chain, product
from pathlib import Path
from typing import TYPE_CHECKING, cast
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
TypeVar,
cast,
)
import toml
@@ -19,15 +27,7 @@ from fixtures.pageserver.common_types import IndexPartDump
from fixtures.pg_version import PgVersion
from fixtures.utils import AuxFileStore
if TYPE_CHECKING:
from typing import (
Any,
Optional,
TypeVar,
cast,
)
T = TypeVar("T")
T = TypeVar("T")
class AbstractNeonCli(abc.ABC):
@@ -37,7 +37,7 @@ class AbstractNeonCli(abc.ABC):
Do not use directly, use specific subclasses instead.
"""
def __init__(self, extra_env: Optional[dict[str, str]], binpath: Path):
def __init__(self, extra_env: Optional[Dict[str, str]], binpath: Path):
self.extra_env = extra_env
self.binpath = binpath
@@ -45,11 +45,11 @@ class AbstractNeonCli(abc.ABC):
def raw_cli(
self,
arguments: list[str],
extra_env_vars: Optional[dict[str, str]] = None,
arguments: List[str],
extra_env_vars: Optional[Dict[str, str]] = None,
check_return_code=True,
timeout=None,
) -> subprocess.CompletedProcess[str]:
) -> "subprocess.CompletedProcess[str]":
"""
Run the command with the specified arguments.
@@ -92,8 +92,9 @@ class AbstractNeonCli(abc.ABC):
args,
env=env_vars,
check=False,
text=True,
capture_output=True,
universal_newlines=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
timeout=timeout,
)
except subprocess.TimeoutExpired as e:
@@ -117,7 +118,7 @@ class AbstractNeonCli(abc.ABC):
if len(lines) < 2:
log.debug(f"Run {res.args} success: {stripped}")
else:
log.debug("Run %s success:\n%s", res.args, textwrap.indent(stripped, indent))
log.debug("Run %s success:\n%s" % (res.args, textwrap.indent(stripped, indent)))
elif check_return_code:
# this way command output will be in recorded and shown in CI in failure message
indent = indent * 2
@@ -174,7 +175,7 @@ class NeonLocalCli(AbstractNeonCli):
def __init__(
self,
extra_env: Optional[dict[str, str]],
extra_env: Optional[Dict[str, str]],
binpath: Path,
repo_dir: Path,
pg_distrib_dir: Path,
@@ -196,7 +197,7 @@ class NeonLocalCli(AbstractNeonCli):
tenant_id: TenantId,
timeline_id: TimelineId,
pg_version: PgVersion,
conf: Optional[dict[str, Any]] = None,
conf: Optional[Dict[str, Any]] = None,
shard_count: Optional[int] = None,
shard_stripe_size: Optional[int] = None,
placement_policy: Optional[str] = None,
@@ -257,7 +258,7 @@ class NeonLocalCli(AbstractNeonCli):
res = self.raw_cli(["tenant", "set-default", "--tenant-id", str(tenant_id)])
res.check_returncode()
def tenant_config(self, tenant_id: TenantId, conf: dict[str, str]):
def tenant_config(self, tenant_id: TenantId, conf: Dict[str, str]):
"""
Update tenant config.
"""
@@ -273,7 +274,7 @@ class NeonLocalCli(AbstractNeonCli):
res = self.raw_cli(args)
res.check_returncode()
def tenant_list(self) -> subprocess.CompletedProcess[str]:
def tenant_list(self) -> "subprocess.CompletedProcess[str]":
res = self.raw_cli(["tenant", "list"])
res.check_returncode()
return res
@@ -367,7 +368,7 @@ class NeonLocalCli(AbstractNeonCli):
res = self.raw_cli(cmd)
res.check_returncode()
def timeline_list(self, tenant_id: TenantId) -> list[tuple[str, TimelineId]]:
def timeline_list(self, tenant_id: TenantId) -> List[Tuple[str, TimelineId]]:
"""
Returns a list of (branch_name, timeline_id) tuples out of parsed `neon timeline list` CLI output.
"""
@@ -388,9 +389,9 @@ class NeonLocalCli(AbstractNeonCli):
def init(
self,
init_config: dict[str, Any],
init_config: Dict[str, Any],
force: Optional[str] = None,
) -> subprocess.CompletedProcess[str]:
) -> "subprocess.CompletedProcess[str]":
with tempfile.NamedTemporaryFile(mode="w+") as init_config_tmpfile:
init_config_tmpfile.write(toml.dumps(init_config))
init_config_tmpfile.flush()
@@ -433,28 +434,29 @@ class NeonLocalCli(AbstractNeonCli):
def pageserver_start(
self,
id: int,
extra_env_vars: Optional[dict[str, str]] = None,
extra_env_vars: Optional[Dict[str, str]] = None,
timeout_in_seconds: Optional[int] = None,
) -> subprocess.CompletedProcess[str]:
) -> "subprocess.CompletedProcess[str]":
start_args = ["pageserver", "start", f"--id={id}"]
if timeout_in_seconds is not None:
start_args.append(f"--start-timeout={timeout_in_seconds}s")
return self.raw_cli(start_args, extra_env_vars=extra_env_vars)
def pageserver_stop(self, id: int, immediate=False) -> subprocess.CompletedProcess[str]:
def pageserver_stop(self, id: int, immediate=False) -> "subprocess.CompletedProcess[str]":
cmd = ["pageserver", "stop", f"--id={id}"]
if immediate:
cmd.extend(["-m", "immediate"])
log.info(f"Stopping pageserver with {cmd}")
return self.raw_cli(cmd)
def safekeeper_start(
self,
id: int,
extra_opts: Optional[list[str]] = None,
extra_env_vars: Optional[dict[str, str]] = None,
extra_opts: Optional[List[str]] = None,
extra_env_vars: Optional[Dict[str, str]] = None,
timeout_in_seconds: Optional[int] = None,
) -> subprocess.CompletedProcess[str]:
) -> "subprocess.CompletedProcess[str]":
if extra_opts is not None:
extra_opts = [f"-e={opt}" for opt in extra_opts]
else:
@@ -467,7 +469,7 @@ class NeonLocalCli(AbstractNeonCli):
def safekeeper_stop(
self, id: Optional[int] = None, immediate=False
) -> subprocess.CompletedProcess[str]:
) -> "subprocess.CompletedProcess[str]":
args = ["safekeeper", "stop"]
if id is not None:
args.append(str(id))
@@ -477,13 +479,13 @@ class NeonLocalCli(AbstractNeonCli):
def storage_broker_start(
self, timeout_in_seconds: Optional[int] = None
) -> subprocess.CompletedProcess[str]:
) -> "subprocess.CompletedProcess[str]":
cmd = ["storage_broker", "start"]
if timeout_in_seconds is not None:
cmd.append(f"--start-timeout={timeout_in_seconds}s")
return self.raw_cli(cmd)
def storage_broker_stop(self) -> subprocess.CompletedProcess[str]:
def storage_broker_stop(self) -> "subprocess.CompletedProcess[str]":
cmd = ["storage_broker", "stop"]
return self.raw_cli(cmd)
@@ -499,7 +501,7 @@ class NeonLocalCli(AbstractNeonCli):
lsn: Optional[Lsn] = None,
pageserver_id: Optional[int] = None,
allow_multiple=False,
) -> subprocess.CompletedProcess[str]:
) -> "subprocess.CompletedProcess[str]":
args = [
"endpoint",
"create",
@@ -532,12 +534,12 @@ class NeonLocalCli(AbstractNeonCli):
def endpoint_start(
self,
endpoint_id: str,
safekeepers: Optional[list[int]] = None,
safekeepers: Optional[List[int]] = None,
remote_ext_config: Optional[str] = None,
pageserver_id: Optional[int] = None,
allow_multiple=False,
basebackup_request_tries: Optional[int] = None,
) -> subprocess.CompletedProcess[str]:
) -> "subprocess.CompletedProcess[str]":
args = [
"endpoint",
"start",
@@ -566,9 +568,9 @@ class NeonLocalCli(AbstractNeonCli):
endpoint_id: str,
tenant_id: Optional[TenantId] = None,
pageserver_id: Optional[int] = None,
safekeepers: Optional[list[int]] = None,
safekeepers: Optional[List[int]] = None,
check_return_code=True,
) -> subprocess.CompletedProcess[str]:
) -> "subprocess.CompletedProcess[str]":
args = ["endpoint", "reconfigure", endpoint_id]
if tenant_id is not None:
args.extend(["--tenant-id", str(tenant_id)])
@@ -584,7 +586,7 @@ class NeonLocalCli(AbstractNeonCli):
destroy=False,
check_return_code=True,
mode: Optional[str] = None,
) -> subprocess.CompletedProcess[str]:
) -> "subprocess.CompletedProcess[str]":
args = [
"endpoint",
"stop",
@@ -600,7 +602,7 @@ class NeonLocalCli(AbstractNeonCli):
def mappings_map_branch(
self, name: str, tenant_id: TenantId, timeline_id: TimelineId
) -> subprocess.CompletedProcess[str]:
) -> "subprocess.CompletedProcess[str]":
"""
Map tenant id and timeline id to a neon_local branch name. They do not have to exist.
Usually needed when creating branches via PageserverHttpClient and not neon_local.
@@ -621,10 +623,10 @@ class NeonLocalCli(AbstractNeonCli):
return self.raw_cli(args, check_return_code=True)
def start(self, check_return_code=True) -> subprocess.CompletedProcess[str]:
def start(self, check_return_code=True) -> "subprocess.CompletedProcess[str]":
return self.raw_cli(["start"], check_return_code=check_return_code)
def stop(self, check_return_code=True) -> subprocess.CompletedProcess[str]:
def stop(self, check_return_code=True) -> "subprocess.CompletedProcess[str]":
return self.raw_cli(["stop"], check_return_code=check_return_code)
@@ -636,7 +638,7 @@ class WalCraft(AbstractNeonCli):
COMMAND = "wal_craft"
def postgres_config(self) -> list[str]:
def postgres_config(self) -> List[str]:
res = self.raw_cli(["print-postgres-config"])
res.check_returncode()
return res.stdout.split("\n")

View File

@@ -13,7 +13,6 @@ import threading
import time
import uuid
from collections import defaultdict
from collections.abc import Iterable, Iterator
from contextlib import closing, contextmanager
from dataclasses import dataclass
from datetime import datetime
@@ -22,7 +21,20 @@ from fcntl import LOCK_EX, LOCK_UN, flock
from functools import cached_property
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, cast
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from urllib.parse import quote, urlparse
import asyncpg
@@ -79,6 +91,7 @@ from fixtures.utils import (
allure_attach_from_dir,
assert_no_errors,
get_dir_size,
get_self_dir,
print_gc_result,
subprocess_capture,
wait_until,
@@ -87,17 +100,7 @@ from fixtures.utils import AuxFileStore as AuxFileStore # reexport
from .neon_api import NeonAPI, NeonApiEndpoint
if TYPE_CHECKING:
from typing import (
Any,
Callable,
Optional,
TypeVar,
Union,
)
T = TypeVar("T")
T = TypeVar("T")
"""
This file contains pytest fixtures. A fixture is a test resource that can be
@@ -116,7 +119,7 @@ Don't import functions from this file, or pytest will emit warnings. Instead
put directly-importable functions into utils.py or another separate file.
"""
Env = dict[str, str]
Env = Dict[str, str]
DEFAULT_OUTPUT_DIR: str = "test_output"
DEFAULT_BRANCH_NAME: str = "main"
@@ -127,7 +130,7 @@ BASE_PORT: int = 15000
@pytest.fixture(scope="session")
def base_dir() -> Iterator[Path]:
# find the base directory (currently this is the git root)
base_dir = Path(__file__).parents[2]
base_dir = get_self_dir().parent.parent
log.info(f"base_dir is {base_dir}")
yield base_dir
@@ -248,7 +251,7 @@ class PgProtocol:
"""
return str(make_dsn(**self.conn_options(**kwargs)))
def conn_options(self, **kwargs: Any) -> dict[str, Any]:
def conn_options(self, **kwargs: Any) -> Dict[str, Any]:
"""
Construct a dictionary of connection options from default values and extra parameters.
An option can be dropped from the returning dictionary by None-valued extra parameter.
@@ -317,7 +320,7 @@ class PgProtocol:
conn_options["server_settings"] = {key: val}
return await asyncpg.connect(**conn_options)
def safe_psql(self, query: str, **kwargs: Any) -> list[tuple[Any, ...]]:
def safe_psql(self, query: str, **kwargs: Any) -> List[Tuple[Any, ...]]:
"""
Execute query against the node and return all rows.
This method passes all extra params to connstr.
@@ -326,12 +329,12 @@ class PgProtocol:
def safe_psql_many(
self, queries: Iterable[str], log_query=True, **kwargs: Any
) -> list[list[tuple[Any, ...]]]:
) -> List[List[Tuple[Any, ...]]]:
"""
Execute queries against the node and return all rows.
This method passes all extra params to connstr.
"""
result: list[list[Any]] = []
result: List[List[Any]] = []
with closing(self.connect(**kwargs)) as conn:
with conn.cursor() as cur:
for query in queries:
@@ -377,7 +380,7 @@ class NeonEnvBuilder:
test_overlay_dir: Optional[Path] = None,
pageserver_remote_storage: Optional[RemoteStorage] = None,
# toml that will be decomposed into `--config-override` flags during `pageserver --init`
pageserver_config_override: Optional[str | Callable[[dict[str, Any]], None]] = None,
pageserver_config_override: Optional[str | Callable[[Dict[str, Any]], None]] = None,
num_safekeepers: int = 1,
num_pageservers: int = 1,
# Use non-standard SK ids to check for various parsing bugs
@@ -392,7 +395,7 @@ class NeonEnvBuilder:
initial_timeline: Optional[TimelineId] = None,
pageserver_virtual_file_io_engine: Optional[str] = None,
pageserver_aux_file_policy: Optional[AuxFileStore] = None,
pageserver_default_tenant_config_compaction_algorithm: Optional[dict[str, Any]] = None,
pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]] = None,
safekeeper_extra_opts: Optional[list[str]] = None,
storage_controller_port_override: Optional[int] = None,
pageserver_virtual_file_io_mode: Optional[str] = None,
@@ -427,7 +430,7 @@ class NeonEnvBuilder:
self.enable_scrub_on_exit = True
self.test_output_dir = test_output_dir
self.test_overlay_dir = test_overlay_dir
self.overlay_mounts_created_by_us: list[tuple[str, Path]] = []
self.overlay_mounts_created_by_us: List[Tuple[str, Path]] = []
self.config_init_force: Optional[str] = None
self.top_output_dir = top_output_dir
self.control_plane_compute_hook_api: Optional[str] = None
@@ -436,7 +439,7 @@ class NeonEnvBuilder:
self.pageserver_virtual_file_io_engine: Optional[str] = pageserver_virtual_file_io_engine
self.pageserver_default_tenant_config_compaction_algorithm: Optional[
dict[str, Any]
Dict[str, Any]
] = pageserver_default_tenant_config_compaction_algorithm
if self.pageserver_default_tenant_config_compaction_algorithm is not None:
log.debug(
@@ -466,7 +469,7 @@ class NeonEnvBuilder:
def init_start(
self,
initial_tenant_conf: Optional[dict[str, Any]] = None,
initial_tenant_conf: Optional[Dict[str, Any]] = None,
default_remote_storage_if_missing: bool = True,
initial_tenant_shard_count: Optional[int] = None,
initial_tenant_shard_stripe_size: Optional[int] = None,
@@ -821,7 +824,7 @@ class NeonEnvBuilder:
overlayfs_mounts = {mountpoint for _, mountpoint in self.overlay_mounts_created_by_us}
directories_to_clean: list[Path] = []
directories_to_clean: List[Path] = []
for test_entry in Path(self.repo_dir).glob("**/*"):
if test_entry in overlayfs_mounts:
continue
@@ -852,12 +855,12 @@ class NeonEnvBuilder:
if isinstance(x, S3Storage):
x.do_cleanup()
def __enter__(self) -> NeonEnvBuilder:
def __enter__(self) -> "NeonEnvBuilder":
return self
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
):
@@ -968,8 +971,8 @@ class NeonEnv:
self.port_distributor = config.port_distributor
self.s3_mock_server = config.mock_s3_server
self.endpoints = EndpointFactory(self)
self.safekeepers: list[Safekeeper] = []
self.pageservers: list[NeonPageserver] = []
self.safekeepers: List[Safekeeper] = []
self.pageservers: List[NeonPageserver] = []
self.broker = NeonBroker(self)
self.pageserver_remote_storage = config.pageserver_remote_storage
self.safekeepers_remote_storage = config.safekeepers_remote_storage
@@ -1041,7 +1044,7 @@ class NeonEnv:
self.pageserver_virtual_file_io_mode = config.pageserver_virtual_file_io_mode
# Create the neon_local's `NeonLocalInitConf`
cfg: dict[str, Any] = {
cfg: Dict[str, Any] = {
"default_tenant_id": str(self.initial_tenant),
"broker": {
"listen_addr": self.broker.listen_addr(),
@@ -1070,7 +1073,7 @@ class NeonEnv:
http=self.port_distributor.get_port(),
)
ps_cfg: dict[str, Any] = {
ps_cfg: Dict[str, Any] = {
"id": ps_id,
"listen_pg_addr": f"localhost:{pageserver_port.pg}",
"listen_http_addr": f"localhost:{pageserver_port.http}",
@@ -1119,7 +1122,7 @@ class NeonEnv:
http=self.port_distributor.get_port(),
)
id = config.safekeepers_id_start + i # assign ids sequentially
sk_cfg: dict[str, Any] = {
sk_cfg: Dict[str, Any] = {
"id": id,
"pg_port": port.pg,
"pg_tenant_only_port": port.pg_tenant_only,
@@ -1284,8 +1287,9 @@ class NeonEnv:
res = subprocess.run(
[bin_pageserver, "--version"],
check=True,
text=True,
capture_output=True,
universal_newlines=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
return res.stdout
@@ -1328,13 +1332,13 @@ class NeonEnv:
self,
tenant_id: Optional[TenantId] = None,
timeline_id: Optional[TimelineId] = None,
conf: Optional[dict[str, Any]] = None,
conf: Optional[Dict[str, Any]] = None,
shard_count: Optional[int] = None,
shard_stripe_size: Optional[int] = None,
placement_policy: Optional[str] = None,
set_default: bool = False,
aux_file_policy: Optional[AuxFileStore] = None,
) -> tuple[TenantId, TimelineId]:
) -> Tuple[TenantId, TimelineId]:
"""
Creates a new tenant, returns its id and its initial timeline's id.
"""
@@ -1355,7 +1359,7 @@ class NeonEnv:
return tenant_id, timeline_id
def config_tenant(self, tenant_id: Optional[TenantId], conf: dict[str, str]):
def config_tenant(self, tenant_id: Optional[TenantId], conf: Dict[str, str]):
"""
Update tenant config.
"""
@@ -1407,7 +1411,7 @@ def neon_simple_env(
pg_version: PgVersion,
pageserver_virtual_file_io_engine: str,
pageserver_aux_file_policy: Optional[AuxFileStore],
pageserver_default_tenant_config_compaction_algorithm: Optional[dict[str, Any]],
pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]],
pageserver_virtual_file_io_mode: Optional[str],
) -> Iterator[NeonEnv]:
"""
@@ -1455,7 +1459,7 @@ def neon_env_builder(
test_overlay_dir: Path,
top_output_dir: Path,
pageserver_virtual_file_io_engine: str,
pageserver_default_tenant_config_compaction_algorithm: Optional[dict[str, Any]],
pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]],
pageserver_aux_file_policy: Optional[AuxFileStore],
record_property: Callable[[str, object], None],
pageserver_virtual_file_io_mode: Optional[str],
@@ -1517,7 +1521,7 @@ class LogUtils:
def assert_log_contains(
self, pattern: str, offset: None | LogCursor = None
) -> tuple[str, LogCursor]:
) -> Tuple[str, LogCursor]:
"""Convenient for use inside wait_until()"""
res = self.log_contains(pattern, offset=offset)
@@ -1526,7 +1530,7 @@ class LogUtils:
def log_contains(
self, pattern: str, offset: None | LogCursor = None
) -> Optional[tuple[str, LogCursor]]:
) -> Optional[Tuple[str, LogCursor]]:
"""Check that the log contains a line that matches the given regex"""
logfile = self.logfile
if not logfile.exists():
@@ -1607,7 +1611,7 @@ class NeonStorageController(MetricsGetter, LogUtils):
self.running = True
return self
def stop(self, immediate: bool = False) -> NeonStorageController:
def stop(self, immediate: bool = False) -> "NeonStorageController":
if self.running:
self.env.neon_cli.storage_controller_stop(immediate)
self.running = False
@@ -1669,7 +1673,7 @@ class NeonStorageController(MetricsGetter, LogUtils):
return resp
def headers(self, scope: Optional[TokenScope]) -> dict[str, str]:
def headers(self, scope: Optional[TokenScope]) -> Dict[str, str]:
headers = {}
if self.auth_enabled and scope is not None:
jwt_token = self.env.auth_keys.generate_token(scope=scope)
@@ -1855,13 +1859,13 @@ class NeonStorageController(MetricsGetter, LogUtils):
tenant_id: TenantId,
shard_count: Optional[int] = None,
shard_stripe_size: Optional[int] = None,
tenant_config: Optional[dict[Any, Any]] = None,
placement_policy: Optional[Union[dict[Any, Any] | str]] = None,
tenant_config: Optional[Dict[Any, Any]] = None,
placement_policy: Optional[Union[Dict[Any, Any] | str]] = None,
):
"""
Use this rather than pageserver_api() when you need to include shard parameters
"""
body: dict[str, Any] = {"new_tenant_id": str(tenant_id)}
body: Dict[str, Any] = {"new_tenant_id": str(tenant_id)}
if shard_count is not None:
shard_params = {"count": shard_count}
@@ -2077,8 +2081,8 @@ class NeonStorageController(MetricsGetter, LogUtils):
time.sleep(backoff)
def metadata_health_update(self, healthy: list[TenantShardId], unhealthy: list[TenantShardId]):
body: dict[str, Any] = {
def metadata_health_update(self, healthy: List[TenantShardId], unhealthy: List[TenantShardId]):
body: Dict[str, Any] = {
"healthy_tenant_shards": [str(t) for t in healthy],
"unhealthy_tenant_shards": [str(t) for t in unhealthy],
}
@@ -2099,7 +2103,7 @@ class NeonStorageController(MetricsGetter, LogUtils):
return response.json()
def metadata_health_list_outdated(self, duration: str):
body: dict[str, Any] = {"not_scrubbed_for": duration}
body: Dict[str, Any] = {"not_scrubbed_for": duration}
response = self.request(
"POST",
@@ -2133,7 +2137,7 @@ class NeonStorageController(MetricsGetter, LogUtils):
response.raise_for_status()
return response.json()
def configure_failpoints(self, config_strings: tuple[str, str] | list[tuple[str, str]]):
def configure_failpoints(self, config_strings: Tuple[str, str] | List[Tuple[str, str]]):
if isinstance(config_strings, tuple):
pairs = [config_strings]
else:
@@ -2150,13 +2154,13 @@ class NeonStorageController(MetricsGetter, LogUtils):
log.info(f"Got failpoints request response code {res.status_code}")
res.raise_for_status()
def get_tenants_placement(self) -> defaultdict[str, dict[str, Any]]:
def get_tenants_placement(self) -> defaultdict[str, Dict[str, Any]]:
"""
Get the intent and observed placements of all tenants known to the storage controller.
"""
tenants = self.tenant_list()
tenant_placement: defaultdict[str, dict[str, Any]] = defaultdict(
tenant_placement: defaultdict[str, Dict[str, Any]] = defaultdict(
lambda: {
"observed": {"attached": None, "secondary": []},
"intent": {"attached": None, "secondary": []},
@@ -2263,12 +2267,12 @@ class NeonStorageController(MetricsGetter, LogUtils):
response.raise_for_status()
return [TenantShardId.parse(tid) for tid in response.json()["updated"]]
def __enter__(self) -> NeonStorageController:
def __enter__(self) -> "NeonStorageController":
return self
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
@@ -2277,7 +2281,7 @@ class NeonStorageController(MetricsGetter, LogUtils):
class NeonProxiedStorageController(NeonStorageController):
def __init__(self, env: NeonEnv, proxy_port: int, auth_enabled: bool):
super().__init__(env, proxy_port, auth_enabled)
super(NeonProxiedStorageController, self).__init__(env, proxy_port, auth_enabled)
self.instances: dict[int, dict[str, Any]] = {}
def start(
@@ -2296,7 +2300,7 @@ class NeonProxiedStorageController(NeonStorageController):
def stop_instance(
self, immediate: bool = False, instance_id: Optional[int] = None
) -> NeonStorageController:
) -> "NeonStorageController":
assert instance_id in self.instances
if self.instances[instance_id]["running"]:
self.env.neon_cli.storage_controller_stop(immediate, instance_id)
@@ -2305,7 +2309,7 @@ class NeonProxiedStorageController(NeonStorageController):
self.running = any(meta["running"] for meta in self.instances.values())
return self
def stop(self, immediate: bool = False) -> NeonStorageController:
def stop(self, immediate: bool = False) -> "NeonStorageController":
for iid, details in self.instances.items():
if details["running"]:
self.env.neon_cli.storage_controller_stop(immediate, iid)
@@ -2324,7 +2328,7 @@ class NeonProxiedStorageController(NeonStorageController):
def log_contains(
self, pattern: str, offset: None | LogCursor = None
) -> Optional[tuple[str, LogCursor]]:
) -> Optional[Tuple[str, LogCursor]]:
raise NotImplementedError()
@@ -2356,7 +2360,7 @@ class NeonPageserver(PgProtocol, LogUtils):
# env.pageserver.allowed_errors.append(".*could not open garage door.*")
#
# The entries in the list are regular expressions.
self.allowed_errors: list[str] = list(DEFAULT_PAGESERVER_ALLOWED_ERRORS)
self.allowed_errors: List[str] = list(DEFAULT_PAGESERVER_ALLOWED_ERRORS)
def timeline_dir(
self,
@@ -2381,19 +2385,19 @@ class NeonPageserver(PgProtocol, LogUtils):
def config_toml_path(self) -> Path:
return self.workdir / "pageserver.toml"
def edit_config_toml(self, edit_fn: Callable[[dict[str, Any]], T]) -> T:
def edit_config_toml(self, edit_fn: Callable[[Dict[str, Any]], T]) -> T:
"""
Edit the pageserver's config toml file in place.
"""
path = self.config_toml_path
with open(path) as f:
with open(path, "r") as f:
config = toml.load(f)
res = edit_fn(config)
with open(path, "w") as f:
toml.dump(config, f)
return res
def patch_config_toml_nonrecursive(self, patch: dict[str, Any]) -> dict[str, Any]:
def patch_config_toml_nonrecursive(self, patch: Dict[str, Any]) -> Dict[str, Any]:
"""
Non-recursively merge the given `patch` dict into the existing config toml, using `dict.update()`.
Returns the replaced values.
@@ -2402,7 +2406,7 @@ class NeonPageserver(PgProtocol, LogUtils):
"""
replacements = {}
def doit(config: dict[str, Any]):
def doit(config: Dict[str, Any]):
while len(patch) > 0:
key, new = patch.popitem()
old = config.get(key, None)
@@ -2414,9 +2418,9 @@ class NeonPageserver(PgProtocol, LogUtils):
def start(
self,
extra_env_vars: Optional[dict[str, str]] = None,
extra_env_vars: Optional[Dict[str, str]] = None,
timeout_in_seconds: Optional[int] = None,
) -> NeonPageserver:
) -> "NeonPageserver":
"""
Start the page server.
`overrides` allows to add some config to this pageserver start.
@@ -2442,7 +2446,7 @@ class NeonPageserver(PgProtocol, LogUtils):
return self
def stop(self, immediate: bool = False) -> NeonPageserver:
def stop(self, immediate: bool = False) -> "NeonPageserver":
"""
Stop the page server.
Returns self.
@@ -2490,12 +2494,12 @@ class NeonPageserver(PgProtocol, LogUtils):
wait_until(20, 0.5, complete)
def __enter__(self) -> NeonPageserver:
def __enter__(self) -> "NeonPageserver":
return self
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
@@ -2542,7 +2546,7 @@ class NeonPageserver(PgProtocol, LogUtils):
def tenant_attach(
self,
tenant_id: TenantId,
config: None | dict[str, Any] = None,
config: None | Dict[str, Any] = None,
generation: Optional[int] = None,
override_storage_controller_generation: bool = False,
):
@@ -2581,7 +2585,7 @@ class NeonPageserver(PgProtocol, LogUtils):
) -> dict[str, Any]:
path = self.tenant_dir(tenant_shard_id) / "config-v1"
log.info(f"Reading location conf from {path}")
bytes = open(path).read()
bytes = open(path, "r").read()
try:
decoded: dict[str, Any] = toml.loads(bytes)
return decoded
@@ -2592,7 +2596,7 @@ class NeonPageserver(PgProtocol, LogUtils):
def tenant_create(
self,
tenant_id: TenantId,
conf: Optional[dict[str, Any]] = None,
conf: Optional[Dict[str, Any]] = None,
auth_token: Optional[str] = None,
generation: Optional[int] = None,
) -> TenantId:
@@ -2658,7 +2662,7 @@ class PgBin:
self.env = os.environ.copy()
self.env["LD_LIBRARY_PATH"] = str(self.pg_lib_dir)
def _fixpath(self, command: list[str]):
def _fixpath(self, command: List[str]):
if "/" not in str(command[0]):
command[0] = str(self.pg_bin_path / command[0])
@@ -2678,7 +2682,7 @@ class PgBin:
def run_nonblocking(
self,
command: list[str],
command: List[str],
env: Optional[Env] = None,
cwd: Optional[Union[str, Path]] = None,
) -> subprocess.Popen[Any]:
@@ -2702,7 +2706,7 @@ class PgBin:
def run(
self,
command: list[str],
command: List[str],
env: Optional[Env] = None,
cwd: Optional[Union[str, Path]] = None,
) -> None:
@@ -2725,7 +2729,7 @@ class PgBin:
def run_capture(
self,
command: list[str],
command: List[str],
env: Optional[Env] = None,
cwd: Optional[str] = None,
with_command_header=True,
@@ -2838,14 +2842,14 @@ class VanillaPostgres(PgProtocol):
]
)
def configure(self, options: list[str]):
def configure(self, options: List[str]):
"""Append lines into postgresql.conf file."""
assert not self.running
with open(os.path.join(self.pgdatadir, "postgresql.conf"), "a") as conf_file:
conf_file.write("\n".join(options))
conf_file.write("\n")
def edit_hba(self, hba: list[str]):
def edit_hba(self, hba: List[str]):
"""Prepend hba lines into pg_hba.conf file."""
assert not self.running
with open(os.path.join(self.pgdatadir, "pg_hba.conf"), "r+") as conf_file:
@@ -2873,12 +2877,12 @@ class VanillaPostgres(PgProtocol):
"""Return size of pgdatadir subdirectory in bytes."""
return get_dir_size(self.pgdatadir / subdir)
def __enter__(self) -> VanillaPostgres:
def __enter__(self) -> "VanillaPostgres":
return self
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
@@ -2908,7 +2912,7 @@ class RemotePostgres(PgProtocol):
# The remote server is assumed to be running already
self.running = True
def configure(self, options: list[str]):
def configure(self, options: List[str]):
raise Exception("cannot change configuration of remote Posgres instance")
def start(self):
@@ -2922,12 +2926,12 @@ class RemotePostgres(PgProtocol):
# See https://www.postgresql.org/docs/14/functions-admin.html#FUNCTIONS-ADMIN-GENFILE
raise Exception("cannot get size of a Postgres instance")
def __enter__(self) -> RemotePostgres:
def __enter__(self) -> "RemotePostgres":
return self
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
@@ -3263,7 +3267,7 @@ class NeonProxy(PgProtocol):
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
@@ -3401,7 +3405,7 @@ class Endpoint(PgProtocol, LogUtils):
self.http_port = http_port
self.check_stop_result = check_stop_result
# passed to endpoint create and endpoint reconfigure
self.active_safekeepers: list[int] = list(map(lambda sk: sk.id, env.safekeepers))
self.active_safekeepers: List[int] = list(map(lambda sk: sk.id, env.safekeepers))
# path to conf is <repo_dir>/endpoints/<endpoint_id>/pgdata/postgresql.conf
# Semaphore is set to 1 when we start, and acquire'd back to zero when we stop
@@ -3424,10 +3428,10 @@ class Endpoint(PgProtocol, LogUtils):
endpoint_id: Optional[str] = None,
hot_standby: bool = False,
lsn: Optional[Lsn] = None,
config_lines: Optional[list[str]] = None,
config_lines: Optional[List[str]] = None,
pageserver_id: Optional[int] = None,
allow_multiple: bool = False,
) -> Endpoint:
) -> "Endpoint":
"""
Create a new Postgres endpoint.
Returns self.
@@ -3470,10 +3474,10 @@ class Endpoint(PgProtocol, LogUtils):
self,
remote_ext_config: Optional[str] = None,
pageserver_id: Optional[int] = None,
safekeepers: Optional[list[int]] = None,
safekeepers: Optional[List[int]] = None,
allow_multiple: bool = False,
basebackup_request_tries: Optional[int] = None,
) -> Endpoint:
) -> "Endpoint":
"""
Start the Postgres instance.
Returns self.
@@ -3486,6 +3490,8 @@ class Endpoint(PgProtocol, LogUtils):
if safekeepers is not None:
self.active_safekeepers = safekeepers
log.info(f"Starting postgres endpoint {self.endpoint_id}")
self.env.neon_cli.endpoint_start(
self.endpoint_id,
safekeepers=self.active_safekeepers,
@@ -3520,7 +3526,7 @@ class Endpoint(PgProtocol, LogUtils):
"""Path to the postgresql.conf in the endpoint directory (not the one in pgdata)"""
return self.endpoint_path() / "postgresql.conf"
def config(self, lines: list[str]) -> Endpoint:
def config(self, lines: List[str]) -> "Endpoint":
"""
Add lines to postgresql.conf.
Lines should be an array of valid postgresql.conf rows.
@@ -3534,7 +3540,7 @@ class Endpoint(PgProtocol, LogUtils):
return self
def edit_hba(self, hba: list[str]):
def edit_hba(self, hba: List[str]):
"""Prepend hba lines into pg_hba.conf file."""
with open(os.path.join(self.pg_data_dir_path(), "pg_hba.conf"), "r+") as conf_file:
data = conf_file.read()
@@ -3549,7 +3555,7 @@ class Endpoint(PgProtocol, LogUtils):
return self._running._value > 0
def reconfigure(
self, pageserver_id: Optional[int] = None, safekeepers: Optional[list[int]] = None
self, pageserver_id: Optional[int] = None, safekeepers: Optional[List[int]] = None
):
assert self.endpoint_id is not None
# If `safekeepers` is not None, remember them as active and use
@@ -3564,7 +3570,7 @@ class Endpoint(PgProtocol, LogUtils):
"""Update the endpoint.json file used by control_plane."""
# Read config
config_path = os.path.join(self.endpoint_path(), "endpoint.json")
with open(config_path) as f:
with open(config_path, "r") as f:
data_dict: dict[str, Any] = json.load(f)
# Write it back updated
@@ -3597,8 +3603,8 @@ class Endpoint(PgProtocol, LogUtils):
def stop(
self,
mode: str = "fast",
sks_wait_walreceiver_gone: Optional[tuple[list[Safekeeper], TimelineId]] = None,
) -> Endpoint:
sks_wait_walreceiver_gone: Optional[tuple[List[Safekeeper], TimelineId]] = None,
) -> "Endpoint":
"""
Stop the Postgres instance if it's running.
@@ -3632,7 +3638,7 @@ class Endpoint(PgProtocol, LogUtils):
return self
def stop_and_destroy(self, mode: str = "immediate") -> Endpoint:
def stop_and_destroy(self, mode: str = "immediate") -> "Endpoint":
"""
Stop the Postgres instance, then destroy the endpoint.
Returns self.
@@ -3654,17 +3660,19 @@ class Endpoint(PgProtocol, LogUtils):
endpoint_id: Optional[str] = None,
hot_standby: bool = False,
lsn: Optional[Lsn] = None,
config_lines: Optional[list[str]] = None,
config_lines: Optional[List[str]] = None,
remote_ext_config: Optional[str] = None,
pageserver_id: Optional[int] = None,
allow_multiple: bool = False,
allow_multiple=False,
basebackup_request_tries: Optional[int] = None,
) -> Endpoint:
) -> "Endpoint":
"""
Create an endpoint, apply config, and start Postgres.
Returns self.
"""
started_at = time.time()
self.create(
branch_name=branch_name,
endpoint_id=endpoint_id,
@@ -3680,14 +3688,16 @@ class Endpoint(PgProtocol, LogUtils):
basebackup_request_tries=basebackup_request_tries,
)
log.info(f"Postgres startup took {time.time() - started_at} seconds")
return self
def __enter__(self) -> Endpoint:
def __enter__(self) -> "Endpoint":
return self
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
):
@@ -3718,7 +3728,7 @@ class EndpointFactory:
def __init__(self, env: NeonEnv):
self.env = env
self.num_instances: int = 0
self.endpoints: list[Endpoint] = []
self.endpoints: List[Endpoint] = []
def create_start(
self,
@@ -3727,7 +3737,7 @@ class EndpointFactory:
tenant_id: Optional[TenantId] = None,
lsn: Optional[Lsn] = None,
hot_standby: bool = False,
config_lines: Optional[list[str]] = None,
config_lines: Optional[List[str]] = None,
remote_ext_config: Optional[str] = None,
pageserver_id: Optional[int] = None,
basebackup_request_tries: Optional[int] = None,
@@ -3759,7 +3769,7 @@ class EndpointFactory:
tenant_id: Optional[TenantId] = None,
lsn: Optional[Lsn] = None,
hot_standby: bool = False,
config_lines: Optional[list[str]] = None,
config_lines: Optional[List[str]] = None,
pageserver_id: Optional[int] = None,
) -> Endpoint:
ep = Endpoint(
@@ -3783,7 +3793,7 @@ class EndpointFactory:
pageserver_id=pageserver_id,
)
def stop_all(self, fail_on_error=True) -> EndpointFactory:
def stop_all(self, fail_on_error=True) -> "EndpointFactory":
exception = None
for ep in self.endpoints:
try:
@@ -3798,7 +3808,7 @@ class EndpointFactory:
return self
def new_replica(
self, origin: Endpoint, endpoint_id: str, config_lines: Optional[list[str]] = None
self, origin: Endpoint, endpoint_id: str, config_lines: Optional[List[str]] = None
):
branch_name = origin.branch_name
assert origin in self.endpoints
@@ -3814,7 +3824,7 @@ class EndpointFactory:
)
def new_replica_start(
self, origin: Endpoint, endpoint_id: str, config_lines: Optional[list[str]] = None
self, origin: Endpoint, endpoint_id: str, config_lines: Optional[List[str]] = None
):
branch_name = origin.branch_name
assert origin in self.endpoints
@@ -3852,7 +3862,7 @@ class Safekeeper(LogUtils):
port: SafekeeperPort,
id: int,
running: bool = False,
extra_opts: Optional[list[str]] = None,
extra_opts: Optional[List[str]] = None,
):
self.env = env
self.port = port
@@ -3878,8 +3888,8 @@ class Safekeeper(LogUtils):
self.extra_opts = extra_opts
def start(
self, extra_opts: Optional[list[str]] = None, timeout_in_seconds: Optional[int] = None
) -> Safekeeper:
self, extra_opts: Optional[List[str]] = None, timeout_in_seconds: Optional[int] = None
) -> "Safekeeper":
if extra_opts is None:
# Apply either the extra_opts passed in, or the ones from our constructor: we do not merge the two.
extra_opts = self.extra_opts
@@ -3914,7 +3924,8 @@ class Safekeeper(LogUtils):
break # success
return self
def stop(self, immediate: bool = False) -> Safekeeper:
def stop(self, immediate: bool = False) -> "Safekeeper":
log.info(f"Stopping safekeeper {self.id}")
self.env.neon_cli.safekeeper_stop(self.id, immediate)
self.running = False
return self
@@ -3925,8 +3936,8 @@ class Safekeeper(LogUtils):
assert not self.log_contains("timeout while acquiring WalResidentTimeline guard")
def append_logical_message(
self, tenant_id: TenantId, timeline_id: TimelineId, request: dict[str, Any]
) -> dict[str, Any]:
self, tenant_id: TenantId, timeline_id: TimelineId, request: Dict[str, Any]
) -> Dict[str, Any]:
"""
Send JSON_CTRL query to append LogicalMessage to WAL and modify
safekeeper state. It will construct LogicalMessage from provided
@@ -3979,7 +3990,7 @@ class Safekeeper(LogUtils):
def pull_timeline(
self, srcs: list[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId
) -> dict[str, Any]:
) -> Dict[str, Any]:
"""
pull_timeline from srcs to self.
"""
@@ -3998,7 +4009,7 @@ class Safekeeper(LogUtils):
def timeline_dir(self, tenant_id, timeline_id) -> Path:
return self.data_dir / str(tenant_id) / str(timeline_id)
# list partial uploaded segments of this safekeeper. Works only for
# List partial uploaded segments of this safekeeper. Works only for
# RemoteStorageKind.LOCAL_FS.
def list_uploaded_segments(self, tenant_id: TenantId, timeline_id: TimelineId):
tline_path = (
@@ -4015,7 +4026,7 @@ class Safekeeper(LogUtils):
mysegs = [s for s in segs if f"sk{self.id}" in s]
return mysegs
def list_segments(self, tenant_id, timeline_id) -> list[str]:
def list_segments(self, tenant_id, timeline_id) -> List[str]:
"""
Get list of segment names of the given timeline.
"""
@@ -4120,7 +4131,7 @@ class StorageScrubber:
self.log_dir = log_dir
def scrubber_cli(
self, args: list[str], timeout, extra_env: Optional[dict[str, str]] = None
self, args: list[str], timeout, extra_env: Optional[Dict[str, str]] = None
) -> str:
assert isinstance(self.env.pageserver_remote_storage, S3Storage)
s3_storage = self.env.pageserver_remote_storage
@@ -4167,10 +4178,10 @@ class StorageScrubber:
def scan_metadata_safekeeper(
self,
timeline_lsns: list[dict[str, Any]],
timeline_lsns: List[Dict[str, Any]],
cloud_admin_api_url: str,
cloud_admin_api_token: str,
) -> tuple[bool, Any]:
) -> Tuple[bool, Any]:
extra_env = {
"CLOUD_ADMIN_API_URL": cloud_admin_api_url,
"CLOUD_ADMIN_API_TOKEN": cloud_admin_api_token,
@@ -4183,9 +4194,9 @@ class StorageScrubber:
self,
post_to_storage_controller: bool = False,
node_kind: NodeKind = NodeKind.PAGESERVER,
timeline_lsns: Optional[list[dict[str, Any]]] = None,
extra_env: Optional[dict[str, str]] = None,
) -> tuple[bool, Any]:
timeline_lsns: Optional[List[Dict[str, Any]]] = None,
extra_env: Optional[Dict[str, str]] = None,
) -> Tuple[bool, Any]:
"""
Returns the health status and the metadata summary.
"""
@@ -4293,7 +4304,7 @@ def pytest_addoption(parser: Parser):
)
SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile(
SMALL_DB_FILE_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
r"config-v1|heatmap-v1|metadata|.+\.(?:toml|pid|json|sql|conf)"
)
@@ -4492,7 +4503,7 @@ def should_skip_file(filename: str) -> bool:
#
# Test helpers
#
def list_files_to_compare(pgdata_dir: Path) -> list[str]:
def list_files_to_compare(pgdata_dir: Path) -> List[str]:
pgdata_files = []
for root, _dirs, filenames in os.walk(pgdata_dir):
for filename in filenames:

View File

@@ -1,13 +1,8 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Iterator
import psutil
if TYPE_CHECKING:
from collections.abc import Iterator
def iter_mounts_beneath(topdir: Path) -> Iterator[Path]:
"""

View File

@@ -1 +0,0 @@
from __future__ import annotations

View File

@@ -1,16 +1,14 @@
#! /usr/bin/env python3
from __future__ import annotations
import argparse
import re
import sys
from collections.abc import Iterable
from typing import Iterable, List, Tuple
def scan_pageserver_log_for_errors(
input: Iterable[str], allowed_errors: list[str]
) -> list[tuple[int, str]]:
input: Iterable[str], allowed_errors: List[str]
) -> List[Tuple[int, str]]:
error_or_warn = re.compile(r"\s(ERROR|WARN)")
errors = []
for lineno, line in enumerate(input, start=1):
@@ -115,7 +113,7 @@ DEFAULT_STORAGE_CONTROLLER_ALLOWED_ERRORS = [
def _check_allowed_errors(input):
allowed_errors: list[str] = list(DEFAULT_PAGESERVER_ALLOWED_ERRORS)
allowed_errors: List[str] = list(DEFAULT_PAGESERVER_ALLOWED_ERRORS)
# add any test specifics here; cli parsing is not provided for the
# difficulty of copypasting regexes as arguments without any quoting

View File

@@ -1,14 +1,9 @@
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union
from typing import Any, Dict, Tuple, Union
from fixtures.common_types import KEY_MAX, KEY_MIN, Key, Lsn
if TYPE_CHECKING:
from typing import Any
@dataclass
class IndexLayerMetadata:
@@ -58,7 +53,7 @@ IMAGE_LAYER_FILE_NAME = re.compile(
)
def parse_image_layer(f_name: str) -> tuple[int, int, int]:
def parse_image_layer(f_name: str) -> Tuple[int, int, int]:
"""Parse an image layer file name. Return key start, key end, and snapshot lsn"""
match = IMAGE_LAYER_FILE_NAME.match(f_name)
@@ -73,7 +68,7 @@ DELTA_LAYER_FILE_NAME = re.compile(
)
def parse_delta_layer(f_name: str) -> tuple[int, int, int, int]:
def parse_delta_layer(f_name: str) -> Tuple[int, int, int, int]:
"""Parse a delta layer file name. Return key start, key end, lsn start, and lsn end"""
match = DELTA_LAYER_FILE_NAME.match(f_name)
if match is None:
@@ -126,11 +121,11 @@ def is_future_layer(layer_file_name: LayerName, disk_consistent_lsn: Lsn):
@dataclass
class IndexPartDump:
layer_metadata: dict[LayerName, IndexLayerMetadata]
layer_metadata: Dict[LayerName, IndexLayerMetadata]
disk_consistent_lsn: Lsn
@classmethod
def from_json(cls, d: dict[str, Any]) -> IndexPartDump:
def from_json(cls, d: Dict[str, Any]) -> "IndexPartDump":
return IndexPartDump(
layer_metadata={
parse_layer_file_name(n): IndexLayerMetadata(v["file_size"], v["generation"])

View File

@@ -4,7 +4,7 @@ import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import requests
from requests.adapters import HTTPAdapter
@@ -16,9 +16,6 @@ from fixtures.metrics import Metrics, MetricsGetter, parse_metrics
from fixtures.pg_version import PgVersion
from fixtures.utils import Fn
if TYPE_CHECKING:
from typing import Optional, Union
class PageserverApiException(Exception):
def __init__(self, message, status_code: int):
@@ -46,7 +43,7 @@ class InMemoryLayerInfo:
lsn_end: Optional[str]
@classmethod
def from_json(cls, d: dict[str, Any]) -> InMemoryLayerInfo:
def from_json(cls, d: Dict[str, Any]) -> InMemoryLayerInfo:
return InMemoryLayerInfo(
kind=d["kind"],
lsn_start=d["lsn_start"],
@@ -67,7 +64,7 @@ class HistoricLayerInfo:
visible: bool
@classmethod
def from_json(cls, d: dict[str, Any]) -> HistoricLayerInfo:
def from_json(cls, d: Dict[str, Any]) -> HistoricLayerInfo:
# instead of parsing the key range lets keep the definition of "L0" in pageserver
l0_ness = d.get("l0")
assert l0_ness is None or isinstance(l0_ness, bool)
@@ -89,53 +86,53 @@ class HistoricLayerInfo:
@dataclass
class LayerMapInfo:
in_memory_layers: list[InMemoryLayerInfo]
historic_layers: list[HistoricLayerInfo]
in_memory_layers: List[InMemoryLayerInfo]
historic_layers: List[HistoricLayerInfo]
@classmethod
def from_json(cls, d: dict[str, Any]) -> LayerMapInfo:
def from_json(cls, d: Dict[str, Any]) -> LayerMapInfo:
info = LayerMapInfo(in_memory_layers=[], historic_layers=[])
json_in_memory_layers = d["in_memory_layers"]
assert isinstance(json_in_memory_layers, list)
assert isinstance(json_in_memory_layers, List)
for json_in_memory_layer in json_in_memory_layers:
info.in_memory_layers.append(InMemoryLayerInfo.from_json(json_in_memory_layer))
json_historic_layers = d["historic_layers"]
assert isinstance(json_historic_layers, list)
assert isinstance(json_historic_layers, List)
for json_historic_layer in json_historic_layers:
info.historic_layers.append(HistoricLayerInfo.from_json(json_historic_layer))
return info
def kind_count(self) -> dict[str, int]:
counts: dict[str, int] = defaultdict(int)
def kind_count(self) -> Dict[str, int]:
counts: Dict[str, int] = defaultdict(int)
for inmem_layer in self.in_memory_layers:
counts[inmem_layer.kind] += 1
for hist_layer in self.historic_layers:
counts[hist_layer.kind] += 1
return counts
def delta_layers(self) -> list[HistoricLayerInfo]:
def delta_layers(self) -> List[HistoricLayerInfo]:
return [x for x in self.historic_layers if x.kind == "Delta"]
def image_layers(self) -> list[HistoricLayerInfo]:
def image_layers(self) -> List[HistoricLayerInfo]:
return [x for x in self.historic_layers if x.kind == "Image"]
def delta_l0_layers(self) -> list[HistoricLayerInfo]:
def delta_l0_layers(self) -> List[HistoricLayerInfo]:
return [x for x in self.historic_layers if x.kind == "Delta" and x.l0]
def historic_by_name(self) -> set[str]:
def historic_by_name(self) -> Set[str]:
return set(x.layer_file_name for x in self.historic_layers)
@dataclass
class TenantConfig:
tenant_specific_overrides: dict[str, Any]
effective_config: dict[str, Any]
tenant_specific_overrides: Dict[str, Any]
effective_config: Dict[str, Any]
@classmethod
def from_json(cls, d: dict[str, Any]) -> TenantConfig:
def from_json(cls, d: Dict[str, Any]) -> TenantConfig:
return TenantConfig(
tenant_specific_overrides=d["tenant_specific_overrides"],
effective_config=d["effective_config"],
@@ -212,7 +209,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
def check_status(self):
self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()
def configure_failpoints(self, config_strings: tuple[str, str] | list[tuple[str, str]]):
def configure_failpoints(self, config_strings: Tuple[str, str] | List[Tuple[str, str]]):
self.is_testing_enabled_or_skip()
if isinstance(config_strings, tuple):
@@ -236,7 +233,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
res = self.post(f"http://localhost:{self.port}/v1/reload_auth_validation_keys")
self.verbose_error(res)
def tenant_list(self) -> list[dict[Any, Any]]:
def tenant_list(self) -> List[Dict[Any, Any]]:
res = self.get(f"http://localhost:{self.port}/v1/tenant")
self.verbose_error(res)
res_json = res.json()
@@ -247,7 +244,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
self,
tenant_id: Union[TenantId, TenantShardId],
generation: int,
config: None | dict[str, Any] = None,
config: None | Dict[str, Any] = None,
):
config = config or {}
@@ -327,7 +324,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
def tenant_status(
self, tenant_id: Union[TenantId, TenantShardId], activate: bool = False
) -> dict[Any, Any]:
) -> Dict[Any, Any]:
"""
:activate: hint the server not to accelerate activation of this tenant in response
to this query. False by default for tests, because they generally want to observe the
@@ -381,8 +378,8 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
def patch_tenant_config_client_side(
self,
tenant_id: TenantId,
inserts: Optional[dict[str, Any]] = None,
removes: Optional[list[str]] = None,
inserts: Optional[Dict[str, Any]] = None,
removes: Optional[List[str]] = None,
):
current = self.tenant_config(tenant_id).tenant_specific_overrides
if inserts is not None:
@@ -397,7 +394,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
def tenant_size_and_modelinputs(
self, tenant_id: Union[TenantId, TenantShardId]
) -> tuple[int, dict[str, Any]]:
) -> Tuple[int, Dict[str, Any]]:
"""
Returns the tenant size, together with the model inputs as the second tuple item.
"""
@@ -427,7 +424,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
tenant_id: Union[TenantId, TenantShardId],
timestamp: datetime,
done_if_after: datetime,
shard_counts: Optional[list[int]] = None,
shard_counts: Optional[List[int]] = None,
):
"""
Issues a request to perform time travel operations on the remote storage
@@ -435,7 +432,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
if shard_counts is None:
shard_counts = []
body: dict[str, Any] = {
body: Dict[str, Any] = {
"shard_counts": shard_counts,
}
res = self.put(
@@ -449,7 +446,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
tenant_id: Union[TenantId, TenantShardId],
include_non_incremental_logical_size: bool = False,
include_timeline_dir_layer_file_size_sum: bool = False,
) -> list[dict[str, Any]]:
) -> List[Dict[str, Any]]:
params = {}
if include_non_incremental_logical_size:
params["include-non-incremental-logical-size"] = "true"
@@ -473,8 +470,8 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
ancestor_start_lsn: Optional[Lsn] = None,
existing_initdb_timeline_id: Optional[TimelineId] = None,
**kwargs,
) -> dict[Any, Any]:
body: dict[str, Any] = {
) -> Dict[Any, Any]:
body: Dict[str, Any] = {
"new_timeline_id": str(new_timeline_id),
"ancestor_start_lsn": str(ancestor_start_lsn) if ancestor_start_lsn else None,
"ancestor_timeline_id": str(ancestor_timeline_id) if ancestor_timeline_id else None,
@@ -507,7 +504,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
include_timeline_dir_layer_file_size_sum: bool = False,
force_await_initial_logical_size: bool = False,
**kwargs,
) -> dict[Any, Any]:
) -> Dict[Any, Any]:
params = {}
if include_non_incremental_logical_size:
params["include-non-incremental-logical-size"] = "true"
@@ -847,7 +844,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
)
if len(res) != 2:
return None
inc, dec = (res[metric] for metric in metrics)
inc, dec = [res[metric] for metric in metrics]
queue_count = int(inc) - int(dec)
assert queue_count >= 0
return queue_count
@@ -888,7 +885,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
timeline_id: TimelineId,
batch_size: int | None = None,
**kwargs,
) -> set[TimelineId]:
) -> Set[TimelineId]:
params = {}
if batch_size is not None:
params["batch_size"] = batch_size

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
import concurrent.futures
from typing import TYPE_CHECKING
from typing import Any, Callable, Dict, Tuple
import fixtures.pageserver.remote_storage
from fixtures.common_types import TenantId, TimelineId
@@ -12,13 +10,10 @@ from fixtures.neon_fixtures import (
)
from fixtures.remote_storage import LocalFsStorage, RemoteStorageKind
if TYPE_CHECKING:
from typing import Any, Callable
def single_timeline(
neon_env_builder: NeonEnvBuilder,
setup_template: Callable[[NeonEnv], tuple[TenantId, TimelineId, dict[str, Any]]],
setup_template: Callable[[NeonEnv], Tuple[TenantId, TimelineId, Dict[str, Any]]],
ncopies: int,
) -> NeonEnv:
"""

View File

@@ -1,12 +1,10 @@
from __future__ import annotations
import concurrent.futures
import os
import queue
import shutil
import threading
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any, List, Tuple
from fixtures.common_types import TenantId, TimelineId
from fixtures.neon_fixtures import NeonEnv
@@ -16,9 +14,6 @@ from fixtures.pageserver.common_types import (
)
from fixtures.remote_storage import LocalFsStorage
if TYPE_CHECKING:
from typing import Any
def duplicate_one_tenant(env: NeonEnv, template_tenant: TenantId, new_tenant: TenantId):
remote_storage = env.pageserver_remote_storage
@@ -55,13 +50,13 @@ def duplicate_one_tenant(env: NeonEnv, template_tenant: TenantId, new_tenant: Te
return None
def duplicate_tenant(env: NeonEnv, template_tenant: TenantId, ncopies: int) -> list[TenantId]:
def duplicate_tenant(env: NeonEnv, template_tenant: TenantId, ncopies: int) -> List[TenantId]:
assert isinstance(env.pageserver_remote_storage, LocalFsStorage)
def work(tenant_id):
duplicate_one_tenant(env, template_tenant, tenant_id)
new_tenants: list[TenantId] = [TenantId.generate() for _ in range(0, ncopies)]
new_tenants: List[TenantId] = [TenantId.generate() for _ in range(0, ncopies)]
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
executor.map(work, new_tenants)
return new_tenants
@@ -84,7 +79,7 @@ def local_layer_name_from_remote_name(remote_name: str) -> str:
def copy_all_remote_layer_files_to_local_tenant_dir(
env: NeonEnv, tenant_timelines: list[tuple[TenantId, TimelineId]]
env: NeonEnv, tenant_timelines: List[Tuple[TenantId, TimelineId]]
):
remote_storage = env.pageserver_remote_storage
assert isinstance(remote_storage, LocalFsStorage)

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, Union
from mypy_boto3_s3.type_defs import (
DeleteObjectOutputTypeDef,
@@ -16,9 +14,6 @@ from fixtures.pageserver.http import PageserverApiException, PageserverHttpClien
from fixtures.remote_storage import RemoteStorage, RemoteStorageKind, S3Storage
from fixtures.utils import wait_until
if TYPE_CHECKING:
from typing import Any, Optional, Union
def assert_tenant_state(
pageserver_http: PageserverHttpClient,
@@ -71,7 +66,7 @@ def wait_for_upload(
)
def _tenant_in_expected_state(tenant_info: dict[str, Any], expected_state: str):
def _tenant_in_expected_state(tenant_info: Dict[str, Any], expected_state: str):
if tenant_info["state"]["slug"] == expected_state:
return True
if tenant_info["state"]["slug"] == "Broken":
@@ -85,7 +80,7 @@ def wait_until_tenant_state(
expected_state: str,
iterations: int,
period: float = 1.0,
) -> dict[str, Any]:
) -> Dict[str, Any]:
"""
Does not use `wait_until` for debugging purposes
"""
@@ -141,7 +136,7 @@ def wait_until_timeline_state(
expected_state: str,
iterations: int,
period: float = 1.0,
) -> dict[str, Any]:
) -> Dict[str, Any]:
"""
Does not use `wait_until` for debugging purposes
"""
@@ -152,7 +147,7 @@ def wait_until_timeline_state(
if isinstance(timeline["state"], str):
if timeline["state"] == expected_state:
return timeline
elif isinstance(timeline, dict):
elif isinstance(timeline, Dict):
if timeline["state"].get(expected_state):
return timeline
@@ -240,7 +235,7 @@ def wait_for_upload_queue_empty(
# this is `started left join finished`; if match, subtracting start from finished, resulting in queue depth
remaining_labels = ["shard_id", "file_kind", "op_kind"]
tl: list[tuple[Any, float]] = []
tl: List[Tuple[Any, float]] = []
for s in started:
found = False
for f in finished:
@@ -307,7 +302,7 @@ def assert_prefix_empty(
assert remote_storage is not None
response = list_prefix(remote_storage, prefix)
keys = response["KeyCount"]
objects: list[ObjectTypeDef] = response.get("Contents", [])
objects: List[ObjectTypeDef] = response.get("Contents", [])
common_prefixes = response.get("CommonPrefixes", [])
is_mock_s3 = isinstance(remote_storage, S3Storage) and not remote_storage.cleanup
@@ -435,7 +430,7 @@ def enable_remote_storage_versioning(
return response
def many_small_layers_tenant_config() -> dict[str, Any]:
def many_small_layers_tenant_config() -> Dict[str, Any]:
"""
Create a new dict to avoid issues with deleting from the global value.
In python, the global is mutable.

Some files were not shown because too many files have changed in this diff Show More