Compare commits

..

4 Commits

Author SHA1 Message Date
Tristan Partin
d3464584a6 Improve some typing in test_runner
Fixes some types, adds some types, and adds some override annotations.

Signed-off-by: Tristan Partin <tristan@neon.tech>
2024-10-09 15:42:22 -05:00
Tristan Partin
878135fe9c Move PgBenchInitResult.EXTRACTORS to a private module constant
This seems to paper over a behavioral difference in Python 3.9 and
Python 3.12 with how dataclasses work with mutable variables. On Python
3.12, I get the following error:

ValueError: mutable default <class 'dict'> for field EXTRACTORS is not allowed: use default_factory

This obviously doesn't occur in our testing environment. When I do what
the error tells me, EXTRACTORS doesn't seem to exist as an attribute on
the class in at least Python 3.9.

The solution provided in this commit seems like the least amount of
friction to keep the wheels turning.

Signed-off-by: Tristan Partin <tristan@neon.tech>
2024-10-09 14:02:09 -05:00
Conrad Ludgate
75434060a5 local_proxy: integrate with pg_session_jwt extension (#9086) 2024-10-09 18:24:10 +01:00
Anastasia Lubennikova
721803a0e7 Add partial support of extensions for v17: (#9322)
- PostGIS 3.5.0
- pgrouting 3.6.2
- h3 4.1.3
- unit 7.9
- pgjwt version (f3d82fd)
- pg_hashids 1.2.1
- ip4r 2.4.2
- prefix 1.2.10
- postgresql-hll 2.18
- pg_roaringbitmap 0.5.4
- pg-semver 0.40.0

update support of extensions for v14-v16:
- unit 7.7 -> 7.9
- pgjwt 9742dab -> f3d82fd

---------

Co-authored-by: Heikki Linnakangas <heikki@neon.tech>
2024-10-09 17:07:59 +01:00
35 changed files with 1089 additions and 238 deletions

18
Cargo.lock generated
View File

@@ -1820,6 +1820,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47"
dependencies = [
"base16ct 0.2.0",
"base64ct",
"crypto-bigint 0.5.5",
"digest",
"ff 0.13.0",
@@ -1829,6 +1830,8 @@ dependencies = [
"pkcs8 0.10.2",
"rand_core 0.6.4",
"sec1 0.7.3",
"serde_json",
"serdect",
"subtle",
"zeroize",
]
@@ -4037,6 +4040,8 @@ dependencies = [
"bytes",
"fallible-iterator",
"postgres-protocol",
"serde",
"serde_json",
]
[[package]]
@@ -5256,6 +5261,7 @@ dependencies = [
"der 0.7.8",
"generic-array",
"pkcs8 0.10.2",
"serdect",
"subtle",
"zeroize",
]
@@ -5510,6 +5516,16 @@ dependencies = [
"syn 2.0.52",
]
[[package]]
name = "serdect"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a84f14a19e9a014bb9f4512488d9829a68e04ecabffb0f9904cd1ace94598177"
dependencies = [
"base16ct 0.2.0",
"serde",
]
[[package]]
name = "sha1"
version = "0.10.5"
@@ -7302,6 +7318,7 @@ dependencies = [
"num-traits",
"once_cell",
"parquet",
"postgres-types",
"prettyplease",
"proc-macro2",
"prost",
@@ -7326,6 +7343,7 @@ dependencies = [
"time",
"time-macros",
"tokio",
"tokio-postgres",
"tokio-stream",
"tokio-util",
"toml_edit",

View File

@@ -109,13 +109,30 @@ RUN apt update && \
libcgal-dev libgdal-dev libgmp-dev libmpfr-dev libopenscenegraph-dev libprotobuf-c-dev \
protobuf-c-compiler xsltproc
# Postgis 3.5.0 requires SFCGAL 1.4+
#
# It would be nice to update all versions together, but we must solve the SFCGAL dependency first.
# SFCGAL > 1.3 requires CGAL > 5.2, Bullseye's libcgal-dev is 5.2
RUN case "${PG_VERSION}" in "v17") \
mkdir -p /sfcgal && \
echo "Postgis doensn't yet support PG17 (needs 3.4.3, if not higher)" && exit 0;; \
# and also we must check backward compatibility with older versions of PostGIS.
#
# Use new version only for v17
RUN case "${PG_VERSION}" in \
"v17") \
export SFCGAL_VERSION=1.4.1 \
export SFCGAL_CHECKSUM=1800c8a26241588f11cddcf433049e9b9aea902e923414d2ecef33a3295626c3 \
;; \
"v14" | "v15" | "v16") \
export SFCGAL_VERSION=1.3.10 \
export SFCGAL_CHECKSUM=4e39b3b2adada6254a7bdba6d297bb28e1a9835a9f879b74f37e2dab70203232 \
;; \
*) \
echo "unexpected PostgreSQL version" && exit 1 \
;; \
esac && \
wget https://gitlab.com/Oslandia/SFCGAL/-/archive/v1.3.10/SFCGAL-v1.3.10.tar.gz -O SFCGAL.tar.gz && \
echo "4e39b3b2adada6254a7bdba6d297bb28e1a9835a9f879b74f37e2dab70203232 SFCGAL.tar.gz" | sha256sum --check && \
mkdir -p /sfcgal && \
wget https://gitlab.com/sfcgal/SFCGAL/-/archive/v${SFCGAL_VERSION}/SFCGAL-v${SFCGAL_VERSION}.tar.gz -O SFCGAL.tar.gz && \
echo "${SFCGAL_CHECKSUM} SFCGAL.tar.gz" | sha256sum --check && \
mkdir sfcgal-src && cd sfcgal-src && tar xzf ../SFCGAL.tar.gz --strip-components=1 -C . && \
cmake -DCMAKE_BUILD_TYPE=Release . && make -j $(getconf _NPROCESSORS_ONLN) && \
DESTDIR=/sfcgal make install -j $(getconf _NPROCESSORS_ONLN) && \
@@ -123,15 +140,27 @@ RUN case "${PG_VERSION}" in "v17") \
ENV PATH="/usr/local/pgsql/bin:$PATH"
RUN case "${PG_VERSION}" in "v17") \
echo "Postgis doensn't yet support PG17 (needs 3.4.3, if not higher)" && exit 0;; \
# Postgis 3.5.0 supports v17
RUN case "${PG_VERSION}" in \
"v17") \
export POSTGIS_VERSION=3.5.0 \
export POSTGIS_CHECKSUM=ca698a22cc2b2b3467ac4e063b43a28413f3004ddd505bdccdd74c56a647f510 \
;; \
"v14" | "v15" | "v16") \
export POSTGIS_VERSION=3.3.3 \
export POSTGIS_CHECKSUM=74eb356e3f85f14233791013360881b6748f78081cc688ff9d6f0f673a762d13 \
;; \
*) \
echo "unexpected PostgreSQL version" && exit 1 \
;; \
esac && \
wget https://download.osgeo.org/postgis/source/postgis-3.3.3.tar.gz -O postgis.tar.gz && \
echo "74eb356e3f85f14233791013360881b6748f78081cc688ff9d6f0f673a762d13 postgis.tar.gz" | sha256sum --check && \
wget https://download.osgeo.org/postgis/source/postgis-${POSTGIS_VERSION}.tar.gz -O postgis.tar.gz && \
echo "${POSTGIS_CHECKSUM} postgis.tar.gz" | sha256sum --check && \
mkdir postgis-src && cd postgis-src && tar xzf ../postgis.tar.gz --strip-components=1 -C . && \
find /usr/local/pgsql -type f | sed 's|^/usr/local/pgsql/||' > /before.txt &&\
./autogen.sh && \
./configure --with-sfcgal=/usr/local/bin/sfcgal-config && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \
cd extensions/postgis && \
make clean && \
@@ -152,11 +181,27 @@ RUN case "${PG_VERSION}" in "v17") \
cp /usr/local/pgsql/share/extension/address_standardizer.control /extensions/postgis && \
cp /usr/local/pgsql/share/extension/address_standardizer_data_us.control /extensions/postgis
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
# Uses versioned libraries, i.e. libpgrouting-3.4
# and may introduce function signature changes between releases
# i.e. release 3.5.0 has new signature for pg_dijkstra function
#
# Use new version only for v17
# last release v3.6.2 - Mar 30, 2024
RUN case "${PG_VERSION}" in \
"v17") \
export PGROUTING_VERSION=3.6.2 \
export PGROUTING_CHECKSUM=f4a1ed79d6f714e52548eca3bb8e5593c6745f1bde92eb5fb858efd8984dffa2 \
;; \
"v14" | "v15" | "v16") \
export PGROUTING_VERSION=3.4.2 \
export PGROUTING_CHECKSUM=cac297c07d34460887c4f3b522b35c470138760fe358e351ad1db4edb6ee306e \
;; \
*) \
echo "unexpected PostgreSQL version" && exit 1 \
;; \
esac && \
wget https://github.com/pgRouting/pgrouting/archive/v3.4.2.tar.gz -O pgrouting.tar.gz && \
echo "cac297c07d34460887c4f3b522b35c470138760fe358e351ad1db4edb6ee306e pgrouting.tar.gz" | sha256sum --check && \
wget https://github.com/pgRouting/pgrouting/archive/v${PGROUTING_VERSION}.tar.gz -O pgrouting.tar.gz && \
echo "${PGROUTING_CHECKSUM} pgrouting.tar.gz" | sha256sum --check && \
mkdir pgrouting-src && cd pgrouting-src && tar xzf ../pgrouting.tar.gz --strip-components=1 -C . && \
mkdir build && cd build && \
cmake -DCMAKE_BUILD_TYPE=Release .. && \
@@ -215,10 +260,9 @@ FROM build-deps AS h3-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN case "${PG_VERSION}" in "v17") \
mkdir -p /h3/usr/ && \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
# not version-specific
# last release v4.1.0 - Jan 18, 2023
RUN mkdir -p /h3/usr/ && \
wget https://github.com/uber/h3/archive/refs/tags/v4.1.0.tar.gz -O h3.tar.gz && \
echo "ec99f1f5974846bde64f4513cf8d2ea1b8d172d2218ab41803bf6a63532272bc h3.tar.gz" | sha256sum --check && \
mkdir h3-src && cd h3-src && tar xzf ../h3.tar.gz --strip-components=1 -C . && \
@@ -229,10 +273,9 @@ RUN case "${PG_VERSION}" in "v17") \
cp -R /h3/usr / && \
rm -rf build
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/zachasme/h3-pg/archive/refs/tags/v4.1.3.tar.gz -O h3-pg.tar.gz && \
# not version-specific
# last release v4.1.3 - Jul 26, 2023
RUN wget https://github.com/zachasme/h3-pg/archive/refs/tags/v4.1.3.tar.gz -O h3-pg.tar.gz && \
echo "5c17f09a820859ffe949f847bebf1be98511fb8f1bd86f94932512c00479e324 h3-pg.tar.gz" | sha256sum --check && \
mkdir h3-pg-src && cd h3-pg-src && tar xzf ../h3-pg.tar.gz --strip-components=1 -C . && \
export PATH="/usr/local/pgsql/bin:$PATH" && \
@@ -251,11 +294,10 @@ FROM build-deps AS unit-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/df7cb/postgresql-unit/archive/refs/tags/7.7.tar.gz -O postgresql-unit.tar.gz && \
echo "411d05beeb97e5a4abf17572bfcfbb5a68d98d1018918feff995f6ee3bb03e79 postgresql-unit.tar.gz" | sha256sum --check && \
# not version-specific
# last release 7.9 - Sep 15, 2024
RUN wget https://github.com/df7cb/postgresql-unit/archive/refs/tags/7.9.tar.gz -O postgresql-unit.tar.gz && \
echo "e46de6245dcc8b2c2ecf29873dbd43b2b346773f31dd5ce4b8315895a052b456 postgresql-unit.tar.gz" | sha256sum --check && \
mkdir postgresql-unit-src && cd postgresql-unit-src && tar xzf ../postgresql-unit.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -302,12 +344,10 @@ FROM build-deps AS pgjwt-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# 9742dab1b2f297ad3811120db7b21451bca2d3c9 made on 13/11/2021
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/michelp/pgjwt/archive/9742dab1b2f297ad3811120db7b21451bca2d3c9.tar.gz -O pgjwt.tar.gz && \
echo "cfdefb15007286f67d3d45510f04a6a7a495004be5b3aecb12cda667e774203f pgjwt.tar.gz" | sha256sum --check && \
# not version-specific
# doesn't use releases, last commit f3d82fd - Mar 2, 2023
RUN wget https://github.com/michelp/pgjwt/archive/f3d82fd30151e754e19ce5d6a06c71c20689ce3d.tar.gz -O pgjwt.tar.gz && \
echo "dae8ed99eebb7593b43013f6532d772b12dfecd55548d2673f2dfd0163f6d2b9 pgjwt.tar.gz" | sha256sum --check && \
mkdir pgjwt-src && cd pgjwt-src && tar xzf ../pgjwt.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgjwt.control
@@ -342,10 +382,9 @@ FROM build-deps AS pg-hashids-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/iCyberon/pg_hashids/archive/refs/tags/v1.2.1.tar.gz -O pg_hashids.tar.gz && \
# not version-specific
# last release v1.2.1 -Jan 12, 2018
RUN wget https://github.com/iCyberon/pg_hashids/archive/refs/tags/v1.2.1.tar.gz -O pg_hashids.tar.gz && \
echo "74576b992d9277c92196dd8d816baa2cc2d8046fe102f3dcd7f3c3febed6822a pg_hashids.tar.gz" | sha256sum --check && \
mkdir pg_hashids-src && cd pg_hashids-src && tar xzf ../pg_hashids.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
@@ -405,10 +444,9 @@ FROM build-deps AS ip4r-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/RhodiumToad/ip4r/archive/refs/tags/2.4.2.tar.gz -O ip4r.tar.gz && \
# not version-specific
# last release v2.4.2 - Jul 29, 2023
RUN wget https://github.com/RhodiumToad/ip4r/archive/refs/tags/2.4.2.tar.gz -O ip4r.tar.gz && \
echo "0f7b1f159974f49a47842a8ab6751aecca1ed1142b6d5e38d81b064b2ead1b4b ip4r.tar.gz" | sha256sum --check && \
mkdir ip4r-src && cd ip4r-src && tar xzf ../ip4r.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -425,10 +463,9 @@ FROM build-deps AS prefix-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.10.tar.gz -O prefix.tar.gz && \
# not version-specific
# last release v1.2.10 - Jul 5, 2023
RUN wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.10.tar.gz -O prefix.tar.gz && \
echo "4342f251432a5f6fb05b8597139d3ccde8dcf87e8ca1498e7ee931ca057a8575 prefix.tar.gz" | sha256sum --check && \
mkdir prefix-src && cd prefix-src && tar xzf ../prefix.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -445,10 +482,9 @@ FROM build-deps AS hll-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.18.tar.gz -O hll.tar.gz && \
# not version-specific
# last release v2.18 - Aug 29, 2023
RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.18.tar.gz -O hll.tar.gz && \
echo "e2f55a6f4c4ab95ee4f1b4a2b73280258c5136b161fe9d059559556079694f0e hll.tar.gz" | sha256sum --check && \
mkdir hll-src && cd hll-src && tar xzf ../hll.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -659,11 +695,10 @@ FROM build-deps AS pg-roaringbitmap-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# not version-specific
# last release v0.5.4 - Jun 28, 2022
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions is not supported yet by pg_roaringbitmap. Quit" && exit 0;; \
esac && \
wget https://github.com/ChenHuajun/pg_roaringbitmap/archive/refs/tags/v0.5.4.tar.gz -O pg_roaringbitmap.tar.gz && \
RUN wget https://github.com/ChenHuajun/pg_roaringbitmap/archive/refs/tags/v0.5.4.tar.gz -O pg_roaringbitmap.tar.gz && \
echo "b75201efcb1c2d1b014ec4ae6a22769cc7a224e6e406a587f5784a37b6b5a2aa pg_roaringbitmap.tar.gz" | sha256sum --check && \
mkdir pg_roaringbitmap-src && cd pg_roaringbitmap-src && tar xzf ../pg_roaringbitmap.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) && \
@@ -680,12 +715,27 @@ FROM build-deps AS pg-semver-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# Release 0.40.0 breaks backward compatibility with previous versions
# see release note https://github.com/theory/pg-semver/releases/tag/v0.40.0
# Use new version only for v17
#
# last release v0.40.0 - Jul 22, 2024
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN case "${PG_VERSION}" in "v17") \
echo "v17 is not supported yet by pg_semver. Quit" && exit 0;; \
RUN case "${PG_VERSION}" in \
"v17") \
export SEMVER_VERSION=0.40.0 \
export SEMVER_CHECKSUM=3e50bcc29a0e2e481e7b6d2bc937cadc5f5869f55d983b5a1aafeb49f5425cfc \
;; \
"v14" | "v15" | "v16") \
export SEMVER_VERSION=0.32.1 \
export SEMVER_CHECKSUM=fbdaf7512026d62eec03fad8687c15ed509b6ba395bff140acd63d2e4fbe25d7 \
;; \
*) \
echo "unexpected PostgreSQL version" && exit 1 \
;; \
esac && \
wget https://github.com/theory/pg-semver/archive/refs/tags/v0.32.1.tar.gz -O pg_semver.tar.gz && \
echo "fbdaf7512026d62eec03fad8687c15ed509b6ba395bff140acd63d2e4fbe25d7 pg_semver.tar.gz" | sha256sum --check && \
wget https://github.com/theory/pg-semver/archive/refs/tags/v${SEMVER_VERSION}.tar.gz -O pg_semver.tar.gz && \
echo "${SEMVER_CHECKSUM} pg_semver.tar.gz" | sha256sum --check && \
mkdir pg_semver-src && cd pg_semver-src && tar xzf ../pg_semver.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \

View File

@@ -77,7 +77,7 @@ subtle.workspace = true
thiserror.workspace = true
tikv-jemallocator.workspace = true
tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] }
tokio-postgres.workspace = true
tokio-postgres = { workspace = true, features = ["with-serde_json-1"] }
tokio-postgres-rustls.workspace = true
tokio-rustls.workspace = true
tokio-util.workspace = true
@@ -101,7 +101,7 @@ jose-jwa = "0.1.2"
jose-jwk = { version = "0.1.2", features = ["p256", "p384", "rsa"] }
signature = "2"
ecdsa = "0.16"
p256 = "0.13"
p256 = { version = "0.13", features = ["jwk"] }
rsa = "0.9"
workspace_hack.workspace = true

View File

@@ -17,6 +17,8 @@ use crate::{
RoleName,
};
use super::ComputeCredentialKeys;
// TODO(conrad): make these configurable.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
const MIN_RENEW: Duration = Duration::from_secs(30);
@@ -241,7 +243,7 @@ impl JwkCacheEntryLock {
endpoint: EndpointId,
role_name: &RoleName,
fetch: &F,
) -> Result<(), anyhow::Error> {
) -> Result<ComputeCredentialKeys, anyhow::Error> {
// JWT compact form is defined to be
// <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
// where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
@@ -300,9 +302,9 @@ impl JwkCacheEntryLock {
key => bail!("unsupported key type {key:?}"),
};
let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
.context("Provided authentication token is not a valid JWT encoding")?;
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)
.context("Provided authentication token is not a valid JWT encoding")?;
tracing::debug!(?payload, "JWT signature valid with claims");
@@ -327,7 +329,7 @@ impl JwkCacheEntryLock {
);
}
Ok(())
Ok(ComputeCredentialKeys::JwtPayload(payloadb))
}
}
@@ -339,7 +341,7 @@ impl JwkCache {
role_name: &RoleName,
fetch: &F,
jwt: &str,
) -> Result<(), anyhow::Error> {
) -> Result<ComputeCredentialKeys, anyhow::Error> {
// try with just a read lock first
let key = (endpoint.clone(), role_name.clone());
let entry = self.map.get(&key).as_deref().map(Arc::clone);

View File

@@ -175,10 +175,12 @@ impl ComputeUserInfo {
}
}
#[cfg_attr(test, derive(Debug))]
pub(crate) enum ComputeCredentialKeys {
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
JwtPayload(Vec<u8>),
None,
}

View File

@@ -309,7 +309,7 @@ impl NodeInfo {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => self.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
ComputeCredentialKeys::None => &mut self.config,
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
};
}
}

View File

@@ -3,10 +3,12 @@ use std::{io, sync::Arc, time::Duration};
use async_trait::async_trait;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use tokio::net::{lookup_host, TcpStream};
use tracing::{field::display, info};
use tokio_postgres::types::ToSql;
use tracing::{debug, field::display, info};
use crate::{
auth::{
self,
backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
check_peer_addr_is_in_list, AuthError,
},
@@ -32,10 +34,12 @@ use crate::{
use super::{
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
http_conn_pool::{self, poll_http2_client},
local_conn_pool::{self, LocalClient, LocalConnPool},
};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -112,7 +116,7 @@ impl PoolingBackend {
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
jwt: String,
) -> Result<(), AuthError> {
) -> Result<ComputeCredentials, AuthError> {
match &self.config.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
config
@@ -127,13 +131,16 @@ impl PoolingBackend {
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
Ok(())
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
})
}
crate::auth::Backend::ConsoleRedirect(_, ()) => Err(AuthError::auth_failed(
"JWT login over web auth proxy is not supported",
)),
crate::auth::Backend::Local(_) => {
config
let keys = config
.jwks_cache
.check_jwt(
ctx,
@@ -145,8 +152,10 @@ impl PoolingBackend {
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
// todo: rewrite JWT signature with key shared somehow between local proxy and postgres
Ok(())
Ok(ComputeCredentials {
info: user_info.clone(),
keys,
})
}
}
}
@@ -231,6 +240,77 @@ impl PoolingBackend {
)
.await
}
/// Connect to postgres over localhost.
///
/// We expect postgres to be started here, so we won't do any retries.
///
/// # Panics
///
/// Panics if called with a non-local_proxy backend.
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
pub(crate) async fn connect_to_local_postgres(
&self,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
) -> Result<LocalClient<tokio_postgres::Client>, HttpConnError> {
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
let mut node_info = match &self.config.auth_backend {
auth::Backend::ControlPlane(_, ()) | auth::Backend::ConsoleRedirect(_, ()) => {
unreachable!("only local_proxy can connect to local postgres")
}
auth::Backend::Local(local) => local.node_info.clone(),
};
let config = node_info
.config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
drop(pause);
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
let handle = local_conn_pool::poll_client(
self.local_pool.clone(),
ctx,
conn_info,
client,
connection,
conn_id,
node_info.aux.clone(),
);
let kid = handle.get_client().get_process_id() as i64;
let jwk = p256::PublicKey::from(handle.key().verifying_key()).to_jwk();
debug!(kid, ?jwk, "setting up backend session state");
// initiates the auth session
handle
.get_client()
.query(
"select auth.init($1, $2);",
&[
&kid as &(dyn ToSql + Sync),
&tokio_postgres::types::Json(jwk),
],
)
.await?;
info!(?kid, "backend session state init");
Ok(handle)
}
}
#[derive(Debug, thiserror::Error)]
@@ -241,6 +321,8 @@ pub(crate) enum HttpConnError {
PostgresConnectionError(#[from] tokio_postgres::Error),
#[error("could not connection to local-proxy in compute")]
LocalProxyConnectionError(#[from] LocalProxyConnError),
#[error("could not parse JWT payload")]
JwtPayloadError(serde_json::Error),
#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
@@ -266,6 +348,7 @@ impl ReportableError for HttpConnError {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::JwtPayloadError(_) => ErrorKind::User,
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
HttpConnError::AuthError(a) => a.get_error_kind(),
HttpConnError::WakeCompute(w) => w.get_error_kind(),
@@ -280,6 +363,7 @@ impl UserFacingError for HttpConnError {
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
HttpConnError::JwtPayloadError(p) => p.to_string(),
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
HttpConnError::AuthError(c) => c.to_string_client(),
HttpConnError::WakeCompute(c) => c.to_string_client(),
@@ -296,6 +380,7 @@ impl CouldRetry for HttpConnError {
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::JwtPayloadError(_) => false,
HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false,
HttpConnError::WakeCompute(_) => false,

View File

@@ -0,0 +1,544 @@
use futures::{future::poll_fn, Future};
use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
use p256::ecdsa::{Signature, SigningKey};
use parking_lot::RwLock;
use rand::rngs::OsRng;
use serde_json::Value;
use signature::Signer;
use std::task::{ready, Poll};
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::types::ToSql;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_util::sync::CancellationToken;
use typed_json::json;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::Metrics;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, DbName, RoleName};
use tracing::{debug, error, warn, Span};
use tracing::{info, info_span, Instrument};
use super::backend::HttpConnError;
use super::conn_pool::{ClientInnerExt, ConnInfo};
struct ConnPoolEntry<C: ClientInnerExt> {
conn: ClientInner<C>,
_last_access: std::time::Instant,
}
// /// key id for the pg_session_jwt state
// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1);
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
total_conns: usize,
max_conns: usize,
global_pool_size_max_conns: usize,
}
impl<C: ClientInnerExt> EndpointConnPool<C> {
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
let Self {
pools, total_conns, ..
} = self;
pools
.get_mut(&db_user)
.and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
}
fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
let Self {
pools, total_conns, ..
} = self;
if let Some(pool) = pools.get_mut(&db_user) {
let old_len = pool.conns.len();
pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
let new_len = pool.conns.len();
let removed = old_len - new_len;
if removed > 0 {
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
}
*total_conns -= removed;
removed > 0
} else {
false
}
}
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
let conn_id = client.conn_id;
if client.is_closed() {
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because connection is closed");
return;
}
let global_max_conn = pool.read().global_pool_size_max_conns;
if pool.read().total_conns >= global_max_conn {
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full");
return;
}
// return connection to the pool
let mut returned = false;
let mut per_db_size = 0;
let total_conns = {
let mut pool = pool.write();
if pool.total_conns < pool.max_conns {
let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
pool_entries.conns.push(ConnPoolEntry {
conn: client,
_last_access: std::time::Instant::now(),
});
returned = true;
per_db_size = pool_entries.conns.len();
pool.total_conns += 1;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.inc();
}
pool.total_conns
};
// do logging outside of the mutex
if returned {
info!(%conn_id, "local_pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
} else {
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
}
}
}
impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
fn drop(&mut self) {
if self.total_conns > 0 {
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(self.total_conns as i64);
}
}
}
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
conns: Vec<ConnPoolEntry<C>>,
}
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
fn default() -> Self {
Self { conns: Vec::new() }
}
}
impl<C: ClientInnerExt> DbUserConnPool<C> {
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
let old_len = self.conns.len();
self.conns.retain(|conn| !conn.conn.is_closed());
let new_len = self.conns.len();
let removed = old_len - new_len;
*conns -= removed;
removed
}
fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry<C>> {
let mut removed = self.clear_closed_clients(conns);
let conn = self.conns.pop();
if conn.is_some() {
*conns -= 1;
removed += 1;
}
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
conn
}
}
pub(crate) struct LocalConnPool<C: ClientInnerExt> {
global_pool: RwLock<EndpointConnPool<C>>,
config: &'static crate::config::HttpConfig,
}
impl<C: ClientInnerExt> LocalConnPool<C> {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
Arc::new(Self {
global_pool: RwLock::new(EndpointConnPool {
pools: HashMap::new(),
total_conns: 0,
max_conns: config.pool_options.max_conns_per_endpoint,
global_pool_size_max_conns: config.pool_options.max_total_conns,
}),
config,
})
}
pub(crate) fn get_idle_timeout(&self) -> Duration {
self.config.pool_options.idle_timeout
}
// pub(crate) fn shutdown(&self) {
// let mut pool = self.global_pool.write();
// pool.pools.clear();
// pool.total_conns = 0;
// }
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Result<Option<LocalClient<C>>, HttpConnError> {
let mut client: Option<ClientInner<C>> = None;
if let Some(entry) = self
.global_pool
.write()
.get_conn_entry(conn_info.db_and_user())
{
client = Some(entry.conn);
}
// ok return cached connection if found and establish a new one otherwise
if let Some(client) = client {
if client.is_closed() {
info!("local_pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
}
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"local_pool: reusing connection '{conn_info}'"
);
client.session.send(ctx.session_id())?;
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
return Ok(Some(LocalClient::new(
client,
conn_info.clone(),
Arc::downgrade(self),
)));
}
Ok(None)
}
}
pub(crate) fn poll_client(
global_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
client: tokio_postgres::Client,
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> LocalClient<tokio_postgres::Client> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let mut session_id = ctx.session_id();
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
let pool = Arc::downgrade(&global_pool);
let pool_clone = pool.clone();
let db_user = conn_info.db_and_user();
let idle = global_pool.get_idle_timeout();
let cancel = CancellationToken::new();
let cancelled = cancel.clone().cancelled_owned();
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let mut idle_timeout = pin!(tokio::time::sleep(idle));
let mut cancelled = pin!(cancelled);
poll_fn(move |cx| {
if cancelled.as_mut().poll(cx).is_ready() {
info!("connection dropped");
return Poll::Ready(())
}
match rx.has_changed() {
Ok(true) => {
session_id = *rx.borrow_and_update();
info!(%session_id, "changed session");
idle_timeout.as_mut().reset(Instant::now() + idle);
}
Err(_) => {
info!("connection dropped");
return Poll::Ready(())
}
_ => {}
}
// 5 minute idle connection timeout
if idle_timeout.as_mut().poll(cx).is_ready() {
idle_timeout.as_mut().reset(Instant::now() + idle);
info!("connection idle");
if let Some(pool) = pool.clone().upgrade() {
// remove client from pool - should close the connection if it's idle.
// does nothing if the client is currently checked-out and in-use
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
info!("idle connection removed");
}
}
}
loop {
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Ok(_)) => {
warn!(%session_id, "unknown message");
}
Some(Err(e)) => {
error!(%session_id, "connection error: {}", e);
break
}
None => {
info!("connection closed");
break
}
}
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;
}
.instrument(span));
let key = SigningKey::random(&mut OsRng);
let inner = ClientInner {
inner: client,
session: tx,
cancel,
aux,
conn_id,
key,
jti: 0,
};
LocalClient::new(inner, conn_info, pool_clone)
}
struct ClientInner<C: ClientInnerExt> {
inner: C,
session: tokio::sync::watch::Sender<uuid::Uuid>,
cancel: CancellationToken,
aux: MetricsAuxInfo,
conn_id: uuid::Uuid,
// needed for pg_session_jwt state
key: SigningKey,
jti: u64,
}
impl<C: ClientInnerExt> Drop for ClientInner<C> {
fn drop(&mut self) {
// on client drop, tell the conn to shut down
self.cancel.cancel();
}
}
impl<C: ClientInnerExt> ClientInner<C> {
pub(crate) fn is_closed(&self) -> bool {
self.inner.is_closed()
}
}
impl<C: ClientInnerExt> LocalClient<C> {
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux;
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
})
}
}
pub(crate) struct LocalClient<C: ClientInnerExt> {
span: Span,
inner: Option<ClientInner<C>>,
conn_info: ConnInfo,
pool: Weak<LocalConnPool<C>>,
}
pub(crate) struct Discard<'a, C: ClientInnerExt> {
conn_info: &'a ConnInfo,
pool: &'a mut Weak<LocalConnPool<C>>,
}
impl<C: ClientInnerExt> LocalClient<C> {
pub(self) fn new(
inner: ClientInner<C>,
conn_info: ConnInfo,
pool: Weak<LocalConnPool<C>>,
) -> Self {
Self {
inner: Some(inner),
span: Span::current(),
conn_info,
pool,
}
}
pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
let Self {
inner,
pool,
conn_info,
span: _,
} = self;
let inner = inner.as_mut().expect("client inner should not be removed");
(&mut inner.inner, Discard { conn_info, pool })
}
pub(crate) fn key(&self) -> &SigningKey {
let inner = &self
.inner
.as_ref()
.expect("client inner should not be removed");
&inner.key
}
}
impl LocalClient<tokio_postgres::Client> {
pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
let inner = self
.inner
.as_mut()
.expect("client inner should not be removed");
inner.jti += 1;
let kid = inner.inner.get_process_id();
let header = json!({"kid":kid}).to_string();
let mut payload = serde_json::from_slice::<serde_json::Map<String, Value>>(payload)
.map_err(HttpConnError::JwtPayloadError)?;
payload.insert("jti".to_string(), Value::Number(inner.jti.into()));
let payload = Value::Object(payload).to_string();
debug!(
kid,
jti = inner.jti,
?header,
?payload,
"signing new ephemeral JWT"
);
let token = sign_jwt(&inner.key, header, payload);
// initiates the auth session
inner.inner.simple_query("discard all").await?;
inner
.inner
.query(
"select auth.jwt_session_init($1)",
&[&token as &(dyn ToSql + Sync)],
)
.await?;
info!(kid, jti = inner.jti, "user session state init");
Ok(())
}
}
fn sign_jwt(sk: &SigningKey, header: String, payload: String) -> String {
let header = Base64UrlUnpadded::encode_string(header.as_bytes());
let payload = Base64UrlUnpadded::encode_string(payload.as_bytes());
let message = format!("{header}.{payload}");
let sig: Signature = sk.sign(message.as_bytes());
let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes());
format!("{message}.{base64_sig}")
}
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!(
"local_pool: throwing away connection '{conn_info}' because connection is not idle"
);
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}
impl<C: ClientInnerExt> LocalClient<C> {
pub fn get_client(&self) -> &C {
&self
.inner
.as_ref()
.expect("client inner should not be removed")
.inner
}
fn do_drop(&mut self) -> Option<impl FnOnce()> {
let conn_info = self.conn_info.clone();
let client = self
.inner
.take()
.expect("client inner should not be removed");
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
let current_span = self.span.clone();
// return connection to the pool
return Some(move || {
let _span = current_span.enter();
EndpointConnPool::put(&conn_pool.global_pool, &conn_info, client);
});
}
None
}
}
impl<C: ClientInnerExt> Drop for LocalClient<C> {
fn drop(&mut self) {
if let Some(drop) = self.do_drop() {
tokio::task::spawn_blocking(drop);
}
}
}

View File

@@ -8,6 +8,7 @@ mod conn_pool;
mod http_conn_pool;
mod http_util;
mod json;
mod local_conn_pool;
mod sql_over_http;
mod websocket;
@@ -63,6 +64,7 @@ pub async fn task_main(
info!("websocket server has shut down");
}
let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config);
let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
{
let conn_pool = Arc::clone(&conn_pool);
@@ -105,6 +107,7 @@ pub async fn task_main(
let backend = Arc::new(PoolingBackend {
http_conn_pool: Arc::clone(&http_conn_pool),
local_pool,
pool: Arc::clone(&conn_pool),
config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),

View File

@@ -40,7 +40,7 @@ use url::Url;
use urlencoding;
use utils::http::error::ApiError;
use crate::auth::backend::ComputeCredentials;
use crate::auth::backend::ComputeCredentialKeys;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
@@ -56,20 +56,22 @@ use crate::metrics::Metrics;
use crate::proxy::run_until_cancelled;
use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
use crate::usage_metrics::MetricCounter;
use crate::usage_metrics::MetricCounterRecorder;
use crate::DbName;
use crate::RoleName;
use super::backend::LocalProxyConnError;
use super::backend::PoolingBackend;
use super::conn_pool;
use super::conn_pool::AuthData;
use super::conn_pool::Client;
use super::conn_pool::ConnInfo;
use super::conn_pool::ConnInfoWithAuth;
use super::http_util::json_response;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::json::JsonConversionError;
use super::local_conn_pool;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -620,6 +622,9 @@ async fn handle_db_inner(
let authenticate_and_connect = Box::pin(
async {
let is_local_proxy =
matches!(backend.config.auth_backend, crate::auth::Backend::Local(_));
let keys = match auth {
AuthData::Password(pw) => {
backend
@@ -639,18 +644,24 @@ async fn handle_db_inner(
&conn_info.user_info,
jwt,
)
.await?;
ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
}
.await?
}
};
let client = match keys.keys {
ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => {
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
client.set_jwt_session(&payload).await?;
Client::Local(client)
}
_ => {
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
Client::Remote(client)
}
};
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
ctx.success();
@@ -791,7 +802,7 @@ impl QueryData {
self,
config: &'static ProxyConfig,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
client: &mut Client,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let (inner, mut discard) = client.inner();
@@ -865,7 +876,7 @@ impl BatchQueryData {
self,
config: &'static ProxyConfig,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
client: &mut Client,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
info!("starting transaction");
@@ -1058,3 +1069,50 @@ async fn query_to_json<T: GenericClient>(
Ok((ready, results))
}
enum Client {
Remote(conn_pool::Client<tokio_postgres::Client>),
Local(local_conn_pool::LocalClient<tokio_postgres::Client>),
}
enum Discard<'a> {
Remote(conn_pool::Discard<'a, tokio_postgres::Client>),
Local(local_conn_pool::Discard<'a, tokio_postgres::Client>),
}
impl Client {
fn metrics(&self) -> Arc<MetricCounter> {
match self {
Client::Remote(client) => client.metrics(),
Client::Local(local_client) => local_client.metrics(),
}
}
fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
match self {
Client::Remote(client) => {
let (c, d) = client.inner();
(c, Discard::Remote(d))
}
Client::Local(local_client) => {
let (c, d) = local_client.inner();
(c, Discard::Local(d))
}
}
}
}
impl Discard<'_> {
fn check_idle(&mut self, status: ReadyForQueryStatus) {
match self {
Discard::Remote(discard) => discard.check_idle(status),
Discard::Local(discard) => discard.check_idle(status),
}
}
fn discard(&mut self) {
match self {
Discard::Remote(discard) => discard.discard(),
Discard::Local(discard) => discard.discard(),
}
}
}

View File

@@ -7,7 +7,6 @@ import json
import os
import re
import timeit
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
@@ -25,7 +24,8 @@ from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonPageserver
if TYPE_CHECKING:
from typing import Callable, ClassVar, Optional
from collections.abc import Iterator, Mapping
from typing import Callable, Optional
"""
@@ -141,6 +141,28 @@ class PgBenchRunResult:
)
# Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171
#
# This used to be a class variable on PgBenchInitResult. However later versions
# of Python complain:
#
# ValueError: mutable default <class 'dict'> for field EXTRACTORS is not allowed: use default_factory
#
# When you do what the error tells you to do, it seems to fail our Python 3.9
# test environment. So let's just move it to a private module constant, and move
# on.
_PGBENCH_INIT_EXTRACTORS: Mapping[str, re.Pattern[str]] = {
"drop_tables": re.compile(r"drop tables (\d+\.\d+) s"),
"create_tables": re.compile(r"create tables (\d+\.\d+) s"),
"client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"),
"server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"),
"vacuum": re.compile(r"vacuum (\d+\.\d+) s"),
"primary_keys": re.compile(r"primary keys (\d+\.\d+) s"),
"foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"),
"total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench
}
@dataclasses.dataclass
class PgBenchInitResult:
total: Optional[float]
@@ -155,20 +177,6 @@ class PgBenchInitResult:
start_timestamp: int
end_timestamp: int
# Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171
EXTRACTORS: ClassVar[dict[str, re.Pattern[str]]] = dataclasses.field(
default_factory=lambda: {
"drop_tables": re.compile(r"drop tables (\d+\.\d+) s"),
"create_tables": re.compile(r"create tables (\d+\.\d+) s"),
"client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"),
"server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"),
"vacuum": re.compile(r"vacuum (\d+\.\d+) s"),
"primary_keys": re.compile(r"primary keys (\d+\.\d+) s"),
"foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"),
"total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench
}
)
@classmethod
def parse_from_stderr(
cls,
@@ -185,7 +193,7 @@ class PgBenchInitResult:
timings: dict[str, Optional[float]] = {}
last_line_items = re.split(r"\(|\)|,", last_line)
for item in last_line_items:
for key, regex in cls.EXTRACTORS.items():
for key, regex in _PGBENCH_INIT_EXTRACTORS.items():
if (m := regex.match(item.strip())) is not None:
if key in timings:
raise RuntimeError(

View File

@@ -6,6 +6,8 @@ from enum import Enum
from functools import total_ordering
from typing import TYPE_CHECKING, TypeVar
from typing_extensions import override
if TYPE_CHECKING:
from typing import Any, Union
@@ -31,33 +33,36 @@ class Lsn:
self.lsn_int = (int(left, 16) << 32) + int(right, 16)
assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF
@override
def __str__(self) -> str:
"""Convert lsn from int to standard hex notation."""
return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}"
@override
def __repr__(self) -> str:
return f'Lsn("{str(self)}")'
def __int__(self) -> int:
return self.lsn_int
def __lt__(self, other: Any) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int < other.lsn_int
def __gt__(self, other: Any) -> bool:
def __gt__(self, other: object) -> bool:
if not isinstance(other, Lsn):
raise NotImplementedError
return self.lsn_int > other.lsn_int
def __eq__(self, other: Any) -> bool:
@override
def __eq__(self, other: object) -> bool:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int == other.lsn_int
# Returns the difference between two Lsns, in bytes
def __sub__(self, other: Any) -> int:
def __sub__(self, other: object) -> int:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int - other.lsn_int
@@ -70,6 +75,7 @@ class Lsn:
else:
raise NotImplementedError
@override
def __hash__(self) -> int:
return hash(self.lsn_int)
@@ -116,19 +122,22 @@ class Id:
self.id = bytearray.fromhex(x)
assert len(self.id) == 16
@override
def __str__(self) -> str:
return self.id.hex()
def __lt__(self, other) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self.id < other.id
def __eq__(self, other) -> bool:
@override
def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self.id == other.id
@override
def __hash__(self) -> int:
return hash(str(self.id))
@@ -139,25 +148,31 @@ class Id:
class TenantId(Id):
@override
def __repr__(self) -> str:
return f'`TenantId("{self.id.hex()}")'
@override
def __str__(self) -> str:
return self.id.hex()
class NodeId(Id):
@override
def __repr__(self) -> str:
return f'`NodeId("{self.id.hex()}")'
@override
def __str__(self) -> str:
return self.id.hex()
class TimelineId(Id):
@override
def __repr__(self) -> str:
return f'TimelineId("{self.id.hex()}")'
@override
def __str__(self) -> str:
return self.id.hex()
@@ -187,7 +202,7 @@ class TenantShardId:
assert self.shard_number < self.shard_count or self.shard_count == 0
@classmethod
def parse(cls: type[TTenantShardId], input) -> TTenantShardId:
def parse(cls: type[TTenantShardId], input: str) -> TTenantShardId:
if len(input) == 32:
return cls(
tenant_id=TenantId(input),
@@ -203,6 +218,7 @@ class TenantShardId:
else:
raise ValueError(f"Invalid TenantShardId '{input}'")
@override
def __str__(self):
if self.shard_count > 0:
return f"{self.tenant_id}-{self.shard_number:02x}{self.shard_count:02x}"
@@ -210,22 +226,25 @@ class TenantShardId:
# Unsharded case: equivalent of Rust TenantShardId::unsharded(tenant_id)
return str(self.tenant_id)
@override
def __repr__(self):
return self.__str__()
def _tuple(self) -> tuple[TenantId, int, int]:
return (self.tenant_id, self.shard_number, self.shard_count)
def __lt__(self, other) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self._tuple() < other._tuple()
def __eq__(self, other) -> bool:
@override
def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self._tuple() == other._tuple()
@override
def __hash__(self) -> int:
return hash(self._tuple())

View File

@@ -8,9 +8,11 @@ from contextlib import _GeneratorContextManager, contextmanager
# Type-related stuff
from pathlib import Path
from typing import TYPE_CHECKING
import pytest
from _pytest.fixtures import FixtureRequest
from typing_extensions import override
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
from fixtures.log_helper import log
@@ -24,6 +26,9 @@ from fixtures.neon_fixtures import (
)
from fixtures.pg_stats import PgStatTable
if TYPE_CHECKING:
from collections.abc import Iterator
class PgCompare(ABC):
"""Common interface of all postgres implementations, useful for benchmarks.
@@ -65,12 +70,12 @@ class PgCompare(ABC):
@contextmanager
@abstractmethod
def record_pageserver_writes(self, out_name):
def record_pageserver_writes(self, out_name: str):
pass
@contextmanager
@abstractmethod
def record_duration(self, out_name):
def record_duration(self, out_name: str):
pass
@contextmanager
@@ -122,28 +127,34 @@ class NeonCompare(PgCompare):
self._pg = self.env.endpoints.create_start("main", "main", self.tenant)
@property
@override
def pg(self) -> PgProtocol:
return self._pg
@property
@override
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
@override
def pg_bin(self) -> PgBin:
return self._pg_bin
@override
def flush(self, compact: bool = True, gc: bool = True):
wait_for_last_flush_lsn(self.env, self._pg, self.tenant, self.timeline)
self.pageserver_http_client.timeline_checkpoint(self.tenant, self.timeline, compact=compact)
if gc:
self.pageserver_http_client.timeline_gc(self.tenant, self.timeline, 0)
@override
def compact(self):
self.pageserver_http_client.timeline_compact(
self.tenant, self.timeline, wait_until_uploaded=True
)
@override
def report_peak_memory_use(self):
self.zenbenchmark.record(
"peak_mem",
@@ -152,6 +163,7 @@ class NeonCompare(PgCompare):
report=MetricReport.LOWER_IS_BETTER,
)
@override
def report_size(self):
timeline_size = self.zenbenchmark.get_timeline_size(
self.env.repo_dir, self.tenant, self.timeline
@@ -185,9 +197,11 @@ class NeonCompare(PgCompare):
"num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER
)
@override
def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name)
@override
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)
@@ -211,26 +225,33 @@ class VanillaCompare(PgCompare):
self.cur = self.conn.cursor()
@property
@override
def pg(self) -> VanillaPostgres:
return self._pg
@property
@override
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
@override
def pg_bin(self) -> PgBin:
return self._pg.pg_bin
@override
def flush(self, compact: bool = False, gc: bool = False):
self.cur.execute("checkpoint")
@override
def compact(self):
pass
@override
def report_peak_memory_use(self):
pass # TODO find something
@override
def report_size(self):
data_size = self.pg.get_subdir_size(Path("base"))
self.zenbenchmark.record(
@@ -245,6 +266,7 @@ class VanillaCompare(PgCompare):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
yield # Do nothing
@override
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)
@@ -261,28 +283,35 @@ class RemoteCompare(PgCompare):
self.cur = self.conn.cursor()
@property
@override
def pg(self) -> PgProtocol:
return self._pg
@property
@override
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
@override
def pg_bin(self) -> PgBin:
return self._pg.pg_bin
def flush(self):
@override
def flush(self, compact: bool = False, gc: bool = False):
# TODO: flush the remote pageserver
pass
@override
def compact(self):
pass
@override
def report_peak_memory_use(self):
# TODO: get memory usage from remote pageserver
pass
@override
def report_size(self):
# TODO: get storage size from remote pageserver
pass
@@ -291,6 +320,7 @@ class RemoteCompare(PgCompare):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
yield # Do nothing
@override
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)

View File

@@ -1,27 +1,31 @@
from __future__ import annotations
import concurrent.futures
from typing import Any
from typing import TYPE_CHECKING
import pytest
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
from fixtures.common_types import TenantId
from fixtures.log_helper import log
if TYPE_CHECKING:
from typing import Any, Callable, Optional
class ComputeReconfigure:
def __init__(self, server):
def __init__(self, server: HTTPServer):
self.server = server
self.control_plane_compute_hook_api = f"http://{server.host}:{server.port}/notify-attach"
self.workloads = {}
self.on_notify = None
self.workloads: dict[TenantId, Any] = {}
self.on_notify: Optional[Callable[[Any], None]] = None
def register_workload(self, workload):
def register_workload(self, workload: Any):
self.workloads[workload.tenant_id] = workload
def register_on_notify(self, fn):
def register_on_notify(self, fn: Optional[Callable[[Any], None]]):
"""
Add some extra work during a notification, like sleeping to slow things down, or
logging what was notified.
@@ -30,7 +34,7 @@ class ComputeReconfigure:
@pytest.fixture(scope="function")
def compute_reconfigure_listener(make_httpserver):
def compute_reconfigure_listener(make_httpserver: HTTPServer):
"""
This fixture exposes an HTTP listener for the storage controller to submit
compute notifications to us, instead of updating neon_local endpoints itself.
@@ -48,7 +52,7 @@ def compute_reconfigure_listener(make_httpserver):
# accept a healthy rate of calls into notify-attach.
reconfigure_threads = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def handler(request: Request):
def handler(request: Request) -> Response:
assert request.json is not None
body: dict[str, Any] = request.json
log.info(f"notify-attach request: {body}")

View File

@@ -14,8 +14,10 @@ from allure_pytest.utils import allure_name, allure_suite_labels
from fixtures.log_helper import log
if TYPE_CHECKING:
from collections.abc import MutableMapping
from typing import Any
"""
The plugin reruns flaky tests.
It uses `pytest.mark.flaky` provided by `pytest-rerunfailures` plugin and flaky tests detected by `scripts/flaky_tests.py`

View File

@@ -1,8 +1,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from pytest_httpserver import HTTPServer
if TYPE_CHECKING:
from collections.abc import Iterator
from fixtures.port_distributor import PortDistributor
# TODO: mypy fails with:
# Module "fixtures.neon_fixtures" does not explicitly export attribute "PortDistributor" [attr-defined]
# from fixtures.neon_fixtures import PortDistributor
@@ -17,7 +24,7 @@ def httpserver_ssl_context():
@pytest.fixture(scope="function")
def make_httpserver(httpserver_listen_address, httpserver_ssl_context):
def make_httpserver(httpserver_listen_address, httpserver_ssl_context) -> Iterator[HTTPServer]:
host, port = httpserver_listen_address
if not host:
host = HTTPServer.DEFAULT_LISTEN_HOST
@@ -33,13 +40,13 @@ def make_httpserver(httpserver_listen_address, httpserver_ssl_context):
@pytest.fixture(scope="function")
def httpserver(make_httpserver):
def httpserver(make_httpserver: HTTPServer) -> Iterator[HTTPServer]:
server = make_httpserver
yield server
server.clear()
@pytest.fixture(scope="function")
def httpserver_listen_address(port_distributor) -> tuple[str, int]:
def httpserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]:
port = port_distributor.get_port()
return ("localhost", port)

View File

@@ -24,25 +24,14 @@ https://docs.pytest.org/en/6.2.x/logging.html
# log format is specified in pytest.ini file
LOGGING = {
"version": 1,
"filters": {
"wzfilter": {
"()": "fixtures.log_helper_internal.WerkzeugNoiseFilter",
},
},
"loggers": {
"root": {"level": "INFO"},
"root.safekeeper_async": {"level": "INFO"}, # a lot of logs on DEBUG level
# Use a custom filter to make werkzeug's messages less verbose.
"werkzeug": {
"filters": ["wzfilter"],
"level": "INFO",
},
},
}
def getLogger(name="root") -> logging.Logger:
def getLogger(name: str = "root") -> logging.Logger:
"""Method to get logger for tests.
Should be used to get correctly initialized logger."""

View File

@@ -1,24 +0,0 @@
# These are logically part of in log_helper.py, but need to be in a
# different file because these get loaded from the logging config
# file. If you try to included these in log_helper.py directly, you
# get an error about circular dependency.
import re
class WerkzeugNoiseFilter(object):
"""Moto server that we use for mocking S3 uses werkzeug, which
logs all HTTP operations. It constructs log messages like this:
127.0.0.1 - - [08/Oct/2024 12:43:46] "PUT /bucket-name/path?x-id=PutObject HTTP/1.1" 200 -
The IP address is not interesting in tests, as it's always just
127.0.0.1. And the timestamp is redundant with the timestamp we
print for all log messages anyway, with millisecond precision.
Unfortunately those are "etched" in the message, and cannot be
overriden by setting a custom formatter. To reduce the noise in
the test output, this filter removes those fields from the log
messages.
"""
def filter(self, logRecord):
logRecord.msg = re.sub(r'127\.0\.0\.1 - - \[.+\] (".*".*)', r'\1', logRecord.msg)
return True

View File

@@ -22,7 +22,7 @@ class Metrics:
def query_all(self, name: str, filter: Optional[dict[str, str]] = None) -> list[Sample]:
filter = filter or {}
res = []
res: list[Sample] = []
for sample in self.metrics[name]:
try:
@@ -59,7 +59,7 @@ class MetricsGetter:
return results[0].value
def get_metrics_values(
self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok=False
self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok: bool = False
) -> dict[str, float]:
"""
When fetching multiple named metrics, it is more efficient to use this

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, cast
import requests
if TYPE_CHECKING:
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Optional
from fixtures.pg_version import PgVersion
@@ -25,9 +25,7 @@ class NeonAPI:
self.__neon_api_key = neon_api_key
self.__neon_api_base_url = neon_api_base_url.strip("/")
def __request(
self, method: Union[str, bytes], endpoint: str, **kwargs: Any
) -> requests.Response:
def __request(self, method: str | bytes, endpoint: str, **kwargs: Any) -> requests.Response:
if "headers" not in kwargs:
kwargs["headers"] = {}
kwargs["headers"]["Authorization"] = f"Bearer {self.__neon_api_key}"
@@ -187,8 +185,8 @@ class NeonAPI:
def get_connection_uri(
self,
project_id: str,
branch_id: Optional[str] = None,
endpoint_id: Optional[str] = None,
branch_id: str | None = None,
endpoint_id: str | None = None,
database_name: str = "neondb",
role_name: str = "neondb_owner",
pooled: bool = True,
@@ -264,7 +262,7 @@ class NeonAPI:
class NeonApiEndpoint:
def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: Optional[str]):
def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: str | None):
self.neon_api = neon_api
if project_id is None:
project = neon_api.create_project(pg_version)

View File

@@ -3657,7 +3657,7 @@ class Endpoint(PgProtocol, LogUtils):
config_lines: Optional[list[str]] = None,
remote_ext_config: Optional[str] = None,
pageserver_id: Optional[int] = None,
allow_multiple=False,
allow_multiple: bool = False,
basebackup_request_tries: Optional[int] = None,
) -> Endpoint:
"""
@@ -3998,7 +3998,7 @@ class Safekeeper(LogUtils):
def timeline_dir(self, tenant_id, timeline_id) -> Path:
return self.data_dir / str(tenant_id) / str(timeline_id)
# List partial uploaded segments of this safekeeper. Works only for
# list partial uploaded segments of this safekeeper. Works only for
# RemoteStorageKind.LOCAL_FS.
def list_uploaded_segments(self, tenant_id: TenantId, timeline_id: TimelineId):
tline_path = (
@@ -4293,7 +4293,7 @@ def pytest_addoption(parser: Parser):
)
SMALL_DB_FILE_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile(
r"config-v1|heatmap-v1|metadata|.+\.(?:toml|pid|json|sql|conf)"
)

View File

@@ -1,10 +1,13 @@
from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING
import psutil
if TYPE_CHECKING:
from collections.abc import Iterator
def iter_mounts_beneath(topdir: Path) -> Iterator[Path]:
"""

View File

@@ -9,7 +9,12 @@ import toml
from _pytest.python import Metafunc
from fixtures.pg_version import PgVersion
from fixtures.utils import AuxFileStore
if TYPE_CHECKING:
from typing import Any, Optional
from fixtures.utils import AuxFileStore
if TYPE_CHECKING:
from typing import Any, Optional

View File

@@ -2,9 +2,14 @@ from __future__ import annotations
import enum
import os
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from typing_extensions import override
if TYPE_CHECKING:
from typing import Optional
"""
This fixture is used to determine which version of Postgres to use for tests.
@@ -24,10 +29,12 @@ class PgVersion(str, enum.Enum):
NOT_SET = "<-POSTRGRES VERSION IS NOT SET->"
# Make it less confusing in logs
@override
def __repr__(self) -> str:
return f"'{self.value}'"
# Make this explicit for Python 3.11 compatibility, which changes the behavior of enums
@override
def __str__(self) -> str:
return self.value
@@ -38,7 +45,8 @@ class PgVersion(str, enum.Enum):
return f"v{self.value}"
@classmethod
def _missing_(cls, value) -> Optional[PgVersion]:
@override
def _missing_(cls, value: object) -> Optional[PgVersion]:
known_values = {v.value for _, v in cls.__members__.items()}
# Allow passing version as a string with "v" prefix (e.g. "v14")

View File

@@ -59,10 +59,7 @@ class PortDistributor:
if isinstance(value, int):
return self._replace_port_int(value)
if isinstance(value, str):
return self._replace_port_str(value)
raise TypeError(f"unsupported type {type(value)} of {value=}")
return self._replace_port_str(value)
def _replace_port_int(self, value: int) -> int:
known_port = self.port_map.get(value)
@@ -75,7 +72,7 @@ class PortDistributor:
# Use regex to find port in a string
# urllib.parse.urlparse produces inconvenient results for cases without scheme like "localhost:5432"
# See https://bugs.python.org/issue27657
ports = re.findall(r":(\d+)(?:/|$)", value)
ports: list[str] = re.findall(r":(\d+)(?:/|$)", value)
assert len(ports) == 1, f"can't find port in {value}"
port_int = int(ports[0])

View File

@@ -13,6 +13,7 @@ import boto3
import toml
from moto.server import ThreadedMotoServer
from mypy_boto3_s3 import S3Client
from typing_extensions import override
from fixtures.common_types import TenantId, TenantShardId, TimelineId
from fixtures.log_helper import log
@@ -36,6 +37,7 @@ class RemoteStorageUser(str, enum.Enum):
EXTENSIONS = "ext"
SAFEKEEPER = "safekeeper"
@override
def __str__(self) -> str:
return self.value
@@ -81,11 +83,13 @@ class LocalFsStorage:
def timeline_path(self, tenant_id: TenantId, timeline_id: TimelineId) -> Path:
return self.tenant_path(tenant_id) / "timelines" / str(timeline_id)
def timeline_latest_generation(self, tenant_id, timeline_id):
def timeline_latest_generation(
self, tenant_id: TenantId, timeline_id: TimelineId
) -> Optional[int]:
timeline_files = os.listdir(self.timeline_path(tenant_id, timeline_id))
index_parts = [f for f in timeline_files if f.startswith("index_part")]
def parse_gen(filename):
def parse_gen(filename: str) -> Optional[int]:
log.info(f"parsing index_part '{filename}'")
parts = filename.split("-")
if len(parts) == 2:
@@ -93,7 +97,7 @@ class LocalFsStorage:
else:
return None
generations = sorted([parse_gen(f) for f in index_parts])
generations = sorted([parse_gen(f) for f in index_parts]) # type: ignore
if len(generations) == 0:
raise RuntimeError(f"No index_part found for {tenant_id}/{timeline_id}")
return generations[-1]
@@ -122,14 +126,14 @@ class LocalFsStorage:
filename = f"{local_name}-{generation:08x}"
return self.timeline_path(tenant_id, timeline_id) / filename
def index_content(self, tenant_id: TenantId, timeline_id: TimelineId):
def index_content(self, tenant_id: TenantId, timeline_id: TimelineId) -> Any:
with self.index_path(tenant_id, timeline_id).open("r") as f:
return json.load(f)
def heatmap_path(self, tenant_id: TenantId) -> Path:
return self.tenant_path(tenant_id) / TENANT_HEATMAP_FILE_NAME
def heatmap_content(self, tenant_id):
def heatmap_content(self, tenant_id: TenantId) -> Any:
with self.heatmap_path(tenant_id).open("r") as f:
return json.load(f)
@@ -297,7 +301,7 @@ class S3Storage:
def heatmap_key(self, tenant_id: TenantId) -> str:
return f"{self.tenant_path(tenant_id)}/{TENANT_HEATMAP_FILE_NAME}"
def heatmap_content(self, tenant_id: TenantId):
def heatmap_content(self, tenant_id: TenantId) -> Any:
r = self.client.get_object(Bucket=self.bucket_name, Key=self.heatmap_key(tenant_id))
return json.loads(r["Body"].read().decode("utf-8"))
@@ -317,7 +321,7 @@ class RemoteStorageKind(str, enum.Enum):
def configure(
self,
repo_dir: Path,
mock_s3_server,
mock_s3_server: MockS3Server,
run_id: str,
test_name: str,
user: RemoteStorageUser,
@@ -451,15 +455,9 @@ def default_remote_storage() -> RemoteStorageKind:
def remote_storage_to_toml_dict(remote_storage: RemoteStorage) -> dict[str, Any]:
if not isinstance(remote_storage, (LocalFsStorage, S3Storage)):
raise Exception("invalid remote storage type")
return remote_storage.to_toml_dict()
# serialize as toml inline table
def remote_storage_to_toml_inline_table(remote_storage: RemoteStorage) -> str:
if not isinstance(remote_storage, (LocalFsStorage, S3Storage)):
raise Exception("invalid remote storage type")
return remote_storage.to_toml_inline_table()

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import re
from typing import Any, Optional
from typing import TYPE_CHECKING
import pytest
import requests
@@ -12,6 +12,9 @@ from werkzeug.wrappers.response import Response
from fixtures.log_helper import log
if TYPE_CHECKING:
from typing import Any, Optional
class StorageControllerProxy:
def __init__(self, server: HTTPServer):
@@ -34,7 +37,7 @@ def proxy_request(method: str, url: str, **kwargs) -> requests.Response:
@pytest.fixture(scope="function")
def storage_controller_proxy(make_httpserver):
def storage_controller_proxy(make_httpserver: HTTPServer):
"""
Proxies requests into the storage controller to the currently
selected storage controller instance via `StorageControllerProxy.route_to`.
@@ -48,7 +51,7 @@ def storage_controller_proxy(make_httpserver):
log.info(f"Storage controller proxy listening on {self.listen}")
def handler(request: Request):
def handler(request: Request) -> Response:
if self.route_to is None:
log.info(f"Storage controller proxy has no routing configured for {request.url}")
return Response("Routing not configured", status=503)

View File

@@ -18,6 +18,7 @@ from urllib.parse import urlencode
import allure
import zstandard
from psycopg2.extensions import cursor
from typing_extensions import override
from fixtures.log_helper import log
from fixtures.pageserver.common_types import (
@@ -26,14 +27,14 @@ from fixtures.pageserver.common_types import (
)
if TYPE_CHECKING:
from typing import (
IO,
Optional,
Union,
)
from collections.abc import Iterable
from typing import IO, Optional
from fixtures.common_types import TimelineId
from fixtures.neon_fixtures import PgBin
from fixtures.common_types import TimelineId
WaitUntilRet = TypeVar("WaitUntilRet")
Fn = TypeVar("Fn", bound=Callable[..., Any])
@@ -42,12 +43,12 @@ def subprocess_capture(
capture_dir: Path,
cmd: list[str],
*,
check=False,
echo_stderr=False,
echo_stdout=False,
capture_stdout=False,
timeout=None,
with_command_header=True,
check: bool = False,
echo_stderr: bool = False,
echo_stdout: bool = False,
capture_stdout: bool = False,
timeout: Optional[float] = None,
with_command_header: bool = True,
**popen_kwargs: Any,
) -> tuple[str, Optional[str], int]:
"""Run a process and bifurcate its output to files and the `log` logger
@@ -84,6 +85,7 @@ def subprocess_capture(
self.capture = capture
self.captured = ""
@override
def run(self):
first = with_command_header
for line in self.in_file:
@@ -165,10 +167,10 @@ def global_counter() -> int:
def print_gc_result(row: dict[str, Any]):
log.info("GC duration {elapsed} ms".format_map(row))
log.info(
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
" needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}".format_map(
row
)
(
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
" needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}"
).format_map(row)
)
@@ -226,7 +228,7 @@ def get_scale_for_db(size_mb: int) -> int:
return round(0.06689 * size_mb - 0.5)
ATTACHMENT_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
ATTACHMENT_NAME_REGEX: re.Pattern[str] = re.compile(
r"regression\.(diffs|out)|.+\.(?:log|stderr|stdout|filediff|metrics|html|walredo)"
)
@@ -289,7 +291,7 @@ LOGS_STAGING_DATASOURCE_ID = "xHHYY0dVz"
def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, end_ms: int):
"""Add links to server logs in Grafana to Allure report"""
links = {}
links: dict[str, str] = {}
# We expect host to be in format like ep-divine-night-159320.us-east-2.aws.neon.build
endpoint_id, region_id, _ = host.split(".", 2)
@@ -341,7 +343,7 @@ def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int,
def start_in_background(
command: list[str], cwd: Path, log_file_name: str, is_started: Fn
command: list[str], cwd: Path, log_file_name: str, is_started: Callable[[], WaitUntilRet]
) -> subprocess.Popen[bytes]:
"""Starts a process, creates the logfile and redirects stderr and stdout there. Runs the start checks before the process is started, or errors."""
@@ -376,14 +378,11 @@ def start_in_background(
return spawned_process
WaitUntilRet = TypeVar("WaitUntilRet")
def wait_until(
number_of_iterations: int,
interval: float,
func: Callable[[], WaitUntilRet],
show_intermediate_error=False,
show_intermediate_error: bool = False,
) -> WaitUntilRet:
"""
Wait until 'func' returns successfully, without exception. Returns the
@@ -464,7 +463,7 @@ def humantime_to_ms(humantime: str) -> float:
def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list[tuple[int, str]]:
# FIXME: this duplicates test_runner/fixtures/pageserver/allowed_errors.py
error_or_warn = re.compile(r"\s(ERROR|WARN)")
errors = []
errors: list[tuple[int, str]] = []
for lineno, line in enumerate(input, start=1):
if len(line) == 0:
continue
@@ -484,7 +483,7 @@ def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list
return errors
def assert_no_errors(log_file, service, allowed_errors):
def assert_no_errors(log_file: Path, service: str, allowed_errors: list[str]):
if not log_file.exists():
log.warning(f"Skipping {service} log check: {log_file} does not exist")
return
@@ -504,9 +503,11 @@ class AuxFileStore(str, enum.Enum):
V2 = "v2"
CrossValidation = "cross-validation"
@override
def __repr__(self) -> str:
return f"'aux-{self.value}'"
@override
def __str__(self) -> str:
return f"'aux-{self.value}'"
@@ -525,7 +526,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
"""
started_at = time.time()
def hash_extracted(reader: Union[IO[bytes], None]) -> bytes:
def hash_extracted(reader: Optional[IO[bytes]]) -> bytes:
assert reader is not None
digest = sha256(usedforsecurity=False)
while True:
@@ -550,7 +551,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
right_list
), f"unexpected number of files on tar files, {len(left_list)} != {len(right_list)}"
mismatching = set()
mismatching: set[str] = set()
for left_tuple, right_tuple in zip(left_list, right_list):
left_path, left_hash = left_tuple
@@ -575,6 +576,7 @@ class PropagatingThread(threading.Thread):
Simple Thread wrapper with join() propagating the possible exception in the thread.
"""
@override
def run(self):
self.exc = None
try:
@@ -582,7 +584,8 @@ class PropagatingThread(threading.Thread):
except BaseException as e:
self.exc = e
def join(self, timeout=None):
@override
def join(self, timeout: Optional[float] = None) -> Any:
super().join(timeout)
if self.exc:
raise self.exc

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import threading
from typing import Any, Optional
from typing import TYPE_CHECKING
from fixtures.common_types import TenantId, TimelineId
from fixtures.log_helper import log
@@ -14,6 +14,9 @@ from fixtures.neon_fixtures import (
)
from fixtures.pageserver.utils import wait_for_last_record_lsn
if TYPE_CHECKING:
from typing import Any, Optional
# neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex
# to ensure we don't do that: this enables running lots of Workloads in parallel safely.
ENDPOINT_LOCK = threading.Lock()
@@ -100,7 +103,7 @@ class Workload:
self.env, endpoint, self.tenant_id, self.timeline_id, pageserver_id=pageserver_id
)
def write_rows(self, n, pageserver_id: Optional[int] = None, upload: bool = True):
def write_rows(self, n: int, pageserver_id: Optional[int] = None, upload: bool = True):
endpoint = self.endpoint(pageserver_id)
start = self.expect_rows
end = start + n - 1
@@ -121,7 +124,9 @@ class Workload:
else:
return False
def churn_rows(self, n, pageserver_id: Optional[int] = None, upload=True, ingest=True):
def churn_rows(
self, n: int, pageserver_id: Optional[int] = None, upload: bool = True, ingest: bool = True
):
assert self.expect_rows >= n
max_iters = 10

View File

@@ -4,7 +4,7 @@ import enum
import json
import os
import time
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from fixtures.log_helper import log
@@ -16,6 +16,10 @@ from fixtures.pageserver.http import PageserverApiException
from fixtures.utils import wait_until
from fixtures.workload import Workload
if TYPE_CHECKING:
from typing import Optional
AGGRESIVE_COMPACTION_TENANT_CONF = {
# Disable gc and compaction. The test runs compaction manually.
"gc_period": "0s",

View File

@@ -15,7 +15,7 @@ import enum
import os
import re
import time
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from fixtures.common_types import TenantId, TimelineId
@@ -40,6 +40,10 @@ from fixtures.remote_storage import (
from fixtures.utils import wait_until
from fixtures.workload import Workload
if TYPE_CHECKING:
from typing import Optional
# A tenant configuration that is convenient for generating uploads and deletions
# without a large amount of postgres traffic.
TENANT_CONF = {

View File

@@ -23,6 +23,7 @@ from fixtures.remote_storage import s3_storage
from fixtures.utils import wait_until
from fixtures.workload import Workload
from pytest_httpserver import HTTPServer
from typing_extensions import override
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
@@ -954,6 +955,7 @@ class PageserverFailpoint(Failure):
self.pageserver_id = pageserver_id
self._mitigate = mitigate
@override
def apply(self, env: NeonEnv):
pageserver = env.get_pageserver(self.pageserver_id)
pageserver.allowed_errors.extend(
@@ -961,19 +963,23 @@ class PageserverFailpoint(Failure):
)
pageserver.http_client().configure_failpoints((self.failpoint, "return(1)"))
@override
def clear(self, env: NeonEnv):
pageserver = env.get_pageserver(self.pageserver_id)
pageserver.http_client().configure_failpoints((self.failpoint, "off"))
if self._mitigate:
env.storage_controller.node_configure(self.pageserver_id, {"availability": "Active"})
@override
def expect_available(self):
return True
@override
def can_mitigate(self):
return self._mitigate
def mitigate(self, env):
@override
def mitigate(self, env: NeonEnv):
env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"})
@@ -983,9 +989,11 @@ class StorageControllerFailpoint(Failure):
self.pageserver_id = None
self.action = action
@override
def apply(self, env: NeonEnv):
env.storage_controller.configure_failpoints((self.failpoint, self.action))
@override
def clear(self, env: NeonEnv):
if "panic" in self.action:
log.info("Restarting storage controller after panic")
@@ -994,16 +1002,19 @@ class StorageControllerFailpoint(Failure):
else:
env.storage_controller.configure_failpoints((self.failpoint, "off"))
@override
def expect_available(self):
# Controller panics _do_ leave pageservers available, but our test code relies
# on using the locate API to update configurations in Workload, so we must skip
# these actions when the controller has been panicked.
return "panic" not in self.action
@override
def can_mitigate(self):
return False
def fails_forward(self, env):
@override
def fails_forward(self, env: NeonEnv):
# Edge case: the very last failpoint that simulates a DB connection error, where
# the abort path will fail-forward and result in a complete split.
fail_forward = self.failpoint == "shard-split-post-complete"
@@ -1017,6 +1028,7 @@ class StorageControllerFailpoint(Failure):
return fail_forward
@override
def expect_exception(self):
if "panic" in self.action:
return requests.exceptions.ConnectionError
@@ -1029,18 +1041,22 @@ class NodeKill(Failure):
self.pageserver_id = pageserver_id
self._mitigate = mitigate
@override
def apply(self, env: NeonEnv):
pageserver = env.get_pageserver(self.pageserver_id)
pageserver.stop(immediate=True)
@override
def clear(self, env: NeonEnv):
pageserver = env.get_pageserver(self.pageserver_id)
pageserver.start()
@override
def expect_available(self):
return False
def mitigate(self, env):
@override
def mitigate(self, env: NeonEnv):
env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"})
@@ -1059,21 +1075,26 @@ class CompositeFailure(Failure):
self.pageserver_id = f.pageserver_id
break
@override
def apply(self, env: NeonEnv):
for f in self.failures:
f.apply(env)
def clear(self, env):
@override
def clear(self, env: NeonEnv):
for f in self.failures:
f.clear(env)
@override
def expect_available(self):
return all(f.expect_available() for f in self.failures)
def mitigate(self, env):
@override
def mitigate(self, env: NeonEnv):
for f in self.failures:
f.mitigate(env)
@override
def expect_exception(self):
expect = set(f.expect_exception() for f in self.failures)
@@ -1211,7 +1232,7 @@ def test_sharding_split_failures(
assert attached_count == initial_shard_count
def assert_split_done(exclude_ps_id=None) -> None:
def assert_split_done(exclude_ps_id: Optional[int] = None) -> None:
secondary_count = 0
attached_count = 0
for ps in env.pageservers:

View File

@@ -1038,7 +1038,7 @@ def test_storage_controller_tenant_deletion(
)
# Break the compute hook: we are checking that deletion does not depend on the compute hook being available
def break_hook():
def break_hook(_body: Any):
raise RuntimeError("Unexpected call to compute hook")
compute_reconfigure_listener.register_on_notify(break_hook)

View File

@@ -6,7 +6,7 @@ import shutil
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from fixtures.common_types import TenantId, TenantShardId, TimelineId
@@ -20,6 +20,9 @@ from fixtures.remote_storage import S3Storage, s3_storage
from fixtures.utils import wait_until
from fixtures.workload import Workload
if TYPE_CHECKING:
from typing import Optional
@pytest.mark.parametrize("shard_count", [None, 4])
def test_scrubber_tenant_snapshot(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]):

View File

@@ -58,6 +58,7 @@ num-integer = { version = "0.1", features = ["i128"] }
num-traits = { version = "0.2", features = ["i128", "libm"] }
once_cell = { version = "1" }
parquet = { version = "53", default-features = false, features = ["zstd"] }
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2", default-features = false, features = ["with-serde_json-1"] }
prost = { version = "0.13", features = ["prost-derive"] }
rand = { version = "0.8", features = ["small_rng"] }
regex = { version = "1" }
@@ -66,7 +67,7 @@ regex-syntax = { version = "0.8" }
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "stream"] }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }
serde_json = { version = "1", features = ["alloc", "raw_value"] }
sha2 = { version = "0.10", features = ["asm", "oid"] }
signature = { version = "2", default-features = false, features = ["digest", "rand_core", "std"] }
smallvec = { version = "1", default-features = false, features = ["const_new", "write"] }
@@ -76,6 +77,7 @@ sync_wrapper = { version = "0.1", default-features = false, features = ["futures
tikv-jemalloc-sys = { version = "0.5" }
time = { version = "0.3", features = ["macros", "serde-well-known"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2", features = ["with-serde_json-1"] }
tokio-stream = { version = "0.1", features = ["net"] }
tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] }
toml_edit = { version = "0.22", features = ["serde"] }