Compare commits


1 Commit

Author: Konstantin Knizhnik
SHA1: 3e8cb25e53
Message: Increase range of expected value for working set approximation test
Date: 2024-10-16 18:59:19 +03:00
107 changed files with 1699 additions and 2383 deletions

View File

@@ -1100,6 +1100,7 @@ jobs:
run: |
if [[ "$GITHUB_REF_NAME" == "main" ]]; then
gh workflow --repo neondatabase/infra run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=false
gh workflow --repo neondatabase/azure run deploy.yml -f dockerTag=${{needs.tag.outputs.build-tag}}
elif [[ "$GITHUB_REF_NAME" == "release" ]]; then
gh workflow --repo neondatabase/infra run deploy-dev.yml --ref main \
-f deployPgSniRouter=false \

Cargo.lock (generated)
View File

@@ -4648,10 +4648,9 @@ dependencies = [
"camino-tempfile",
"futures",
"futures-util",
"http-body-util",
"http-types",
"humantime-serde",
"hyper 1.4.1",
"hyper 0.14.30",
"itertools 0.10.5",
"metrics",
"once_cell",

View File

@@ -31,7 +31,7 @@ See developer documentation in [SUMMARY.md](/docs/SUMMARY.md) for more informati
```bash
apt install build-essential libtool libreadline-dev zlib1g-dev flex bison libseccomp-dev \
libssl-dev clang pkg-config libpq-dev cmake postgresql-client protobuf-compiler \
libprotobuf-dev libcurl4-openssl-dev openssl python3-poetry lsof libicu-dev
libcurl4-openssl-dev openssl python3-poetry lsof libicu-dev
```
* On Fedora, these packages are needed:
```bash

View File

@@ -18,14 +18,13 @@ RUN case $DEBIAN_VERSION in \
# Version-specific installs for Bullseye (PG14-PG16):
# The h3_pg extension needs cmake 3.20+, but Debian bullseye has 3.18.
# Install newer version (3.25) from backports.
# libstdc++-10-dev is required for plv8
bullseye) \
echo "deb http://deb.debian.org/debian bullseye-backports main" > /etc/apt/sources.list.d/bullseye-backports.list; \
VERSION_INSTALLS="cmake/bullseye-backports cmake-data/bullseye-backports libstdc++-10-dev"; \
VERSION_INSTALLS="cmake/bullseye-backports cmake-data/bullseye-backports"; \
;; \
# Version-specific installs for Bookworm (PG17):
bookworm) \
VERSION_INSTALLS="cmake libstdc++-12-dev"; \
VERSION_INSTALLS="cmake"; \
;; \
*) \
echo "Unknown Debian version ${DEBIAN_VERSION}" && exit 1 \
@@ -228,33 +227,18 @@ FROM build-deps AS plv8-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN apt update && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
apt update && \
apt install --no-install-recommends -y ninja-build python3-dev libncurses5 binutils clang
# plv8 3.2.3 supports v17
# last release v3.2.3 - Sep 7, 2024
#
# clone the repo instead of downloading the release tarball because plv8 has submodule dependencies
# and the release tarball doesn't include them
#
# Use new version only for v17
# because since v3.2, plv8 doesn't include plcoffee and plls extensions
ENV PLV8_TAG=v3.2.3
RUN case "${PG_VERSION}" in \
"v17") \
export PLV8_TAG=v3.2.3 \
;; \
"v14" | "v15" | "v16") \
export PLV8_TAG=v3.1.10 \
;; \
*) \
echo "unexpected PostgreSQL version" && exit 1 \
;; \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
git clone --recurse-submodules --depth 1 --branch ${PLV8_TAG} https://github.com/plv8/plv8.git plv8-src && \
tar -czf plv8.tar.gz --exclude .git plv8-src && \
cd plv8-src && \
wget https://github.com/plv8/plv8/archive/refs/tags/v3.1.10.tar.gz -O plv8.tar.gz && \
echo "7096c3290928561f0d4901b7a52794295dc47f6303102fae3f8e42dd575ad97d plv8.tar.gz" | sha256sum --check && \
mkdir plv8-src && cd plv8-src && tar xzf ../plv8.tar.gz --strip-components=1 -C . && \
# generate and copy upgrade scripts
mkdir -p upgrade && ./generate_upgrade.sh 3.1.10 && \
cp upgrade/* /usr/local/pgsql/share/extension/ && \
@@ -264,17 +248,8 @@ RUN case "${PG_VERSION}" in \
find /usr/local/pgsql/ -name "plv8-*.so" | xargs strip && \
# don't break computes with installed old version of plv8
cd /usr/local/pgsql/lib/ && \
case "${PG_VERSION}" in \
"v17") \
ln -s plv8-3.2.3.so plv8-3.1.8.so && \
ln -s plv8-3.2.3.so plv8-3.1.5.so && \
ln -s plv8-3.2.3.so plv8-3.1.10.so \
;; \
"v14" | "v15" | "v16") \
ln -s plv8-3.1.10.so plv8-3.1.5.so && \
ln -s plv8-3.1.10.so plv8-3.1.8.so \
;; \
esac && \
ln -s plv8-3.1.10.so plv8-3.1.5.so && \
ln -s plv8-3.1.10.so plv8-3.1.8.so && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/plv8.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/plcoffee.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/plls.control
@@ -352,9 +327,6 @@ COPY compute/patches/pgvector.patch /pgvector.patch
# By default, pgvector Makefile uses `-march=native`. We don't want that,
# because we build the images on different machines than where we run them.
# Pass OPTFLAGS="" to remove it.
#
# v17 is not supported yet because of upstream issue
# https://github.com/pgvector/pgvector/issues/669
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
@@ -394,10 +366,11 @@ FROM build-deps AS hypopg-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# HypoPG 1.4.1 supports v17
# last release 1.4.1 - Apr 28, 2024
RUN wget https://github.com/HypoPG/hypopg/archive/refs/tags/1.4.1.tar.gz -O hypopg.tar.gz && \
echo "9afe6357fd389d8d33fad81703038ce520b09275ec00153c6c89282bcdedd6bc hypopg.tar.gz" | sha256sum --check && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/HypoPG/hypopg/archive/refs/tags/1.4.0.tar.gz -O hypopg.tar.gz && \
echo "0821011743083226fc9b813c1f2ef5897a91901b57b6bea85a78e466187c6819 hypopg.tar.gz" | sha256sum --check && \
mkdir hypopg-src && cd hypopg-src && tar xzf ../hypopg.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
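
Each of these build stages pins its source tarball to a known SHA-256 digest before unpacking (`wget ... && echo "<digest>  <file>" | sha256sum --check`). A minimal Rust sketch of the same pin-by-digest check, assuming the `sha2` and `hex` crates:

```rust
use sha2::{Digest, Sha256};

/// Mirror of the `sha256sum --check` step: hash the downloaded bytes and
/// compare against the pinned hex digest before trusting the archive.
fn digest_matches(bytes: &[u8], expected_hex: &str) -> bool {
    let actual = hex::encode(Sha256::digest(bytes));
    actual.eq_ignore_ascii_case(expected_hex)
}
```

Pinning by digest means a re-tagged or tampered release fails the build instead of silently changing the image.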
@@ -434,9 +407,6 @@ COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY compute/patches/rum.patch /rum.patch
# maybe version-specific
# support for v17 is unknown
# last release 1.3.13 - Sep 19, 2022
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
@@ -458,10 +428,11 @@ FROM build-deps AS pgtap-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# pgtap 1.3.3 supports v17
# last release v1.3.3 - Apr 8, 2024
RUN wget https://github.com/theory/pgtap/archive/refs/tags/v1.3.3.tar.gz -O pgtap.tar.gz && \
echo "325ea79d0d2515bce96bce43f6823dcd3effbd6c54cb2a4d6c2384fffa3a14c7 pgtap.tar.gz" | sha256sum --check && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/theory/pgtap/archive/refs/tags/v1.2.0.tar.gz -O pgtap.tar.gz && \
echo "9c7c3de67ea41638e14f06da5da57bac6f5bd03fea05c165a0ec862205a5c052 pgtap.tar.gz" | sha256sum --check && \
mkdir pgtap-src && cd pgtap-src && tar xzf ../pgtap.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -534,10 +505,11 @@ FROM build-deps AS plpgsql-check-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# plpgsql_check v2.7.11 supports v17
# last release v2.7.11 - Sep 16, 2024
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.7.11.tar.gz -O plpgsql_check.tar.gz && \
echo "208933f8dbe8e0d2628eb3851e9f52e6892b8e280c63700c0f1ce7883625d172 plpgsql_check.tar.gz" | sha256sum --check && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.5.3.tar.gz -O plpgsql_check.tar.gz && \
echo "6631ec3e7fb3769eaaf56e3dfedb829aa761abf163d13dba354b4c218508e1c0 plpgsql_check.tar.gz" | sha256sum --check && \
mkdir plpgsql_check-src && cd plpgsql_check-src && tar xzf ../plpgsql_check.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
@@ -555,19 +527,18 @@ COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
ARG PG_VERSION
ENV PATH="/usr/local/pgsql/bin:$PATH"
RUN case "${PG_VERSION}" in \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
case "${PG_VERSION}" in \
"v14" | "v15") \
export TIMESCALEDB_VERSION=2.10.1 \
export TIMESCALEDB_CHECKSUM=6fca72a6ed0f6d32d2b3523951ede73dc5f9b0077b38450a029a5f411fdb8c73 \
;; \
"v16") \
*) \
export TIMESCALEDB_VERSION=2.13.0 \
export TIMESCALEDB_CHECKSUM=584a351c7775f0e067eaa0e7277ea88cab9077cc4c455cbbf09a5d9723dce95d \
;; \
"v17") \
export TIMESCALEDB_VERSION=2.17.0 \
export TIMESCALEDB_CHECKSUM=155bf64391d3558c42f31ca0e523cfc6252921974f75298c9039ccad1c89811a \
;; \
esac && \
wget https://github.com/timescale/timescaledb/archive/refs/tags/${TIMESCALEDB_VERSION}.tar.gz -O timescaledb.tar.gz && \
echo "${TIMESCALEDB_CHECKSUM} timescaledb.tar.gz" | sha256sum --check && \
@@ -590,8 +561,10 @@ COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
ARG PG_VERSION
ENV PATH="/usr/local/pgsql/bin:$PATH"
# version-specific, has separate releases for each version
RUN case "${PG_VERSION}" in \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
case "${PG_VERSION}" in \
"v14") \
export PG_HINT_PLAN_VERSION=14_1_4_1 \
export PG_HINT_PLAN_CHECKSUM=c3501becf70ead27f70626bce80ea401ceac6a77e2083ee5f3ff1f1444ec1ad1 \
@@ -605,8 +578,7 @@ RUN case "${PG_VERSION}" in \
export PG_HINT_PLAN_CHECKSUM=fc85a9212e7d2819d4ae4ac75817481101833c3cfa9f0fe1f980984e12347d00 \
;; \
"v17") \
export PG_HINT_PLAN_VERSION=17_1_7_0 \
export PG_HINT_PLAN_CHECKSUM=06dd306328c67a4248f48403c50444f30959fb61ebe963248dbc2afb396fe600 \
echo "TODO: PG17 pg_hint_plan support" && exit 0 \
;; \
*) \
echo "Export the valid PG_HINT_PLAN_VERSION variable" && exit 1 \
@@ -630,10 +602,6 @@ FROM build-deps AS pg-cron-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# 1.6.4 available, supports v17
# This is an experimental extension that we do not support on prod yet.
# !Do not remove!
# We set it in shared_preload_libraries and computes will fail to start if library is not found.
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
@@ -655,37 +623,23 @@ FROM build-deps AS rdkit-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN apt-get update && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
apt-get update && \
apt-get install --no-install-recommends -y \
libboost-iostreams1.74-dev \
libboost-regex1.74-dev \
libboost-serialization1.74-dev \
libboost-system1.74-dev \
libeigen3-dev \
libboost-all-dev
libeigen3-dev
# rdkit Release_2024_09_1 supports v17
# last release Release_2024_09_1 - Sep 27, 2024
#
# Use new version only for v17
# because Release_2024_09_1 has some backward incompatible changes
# https://github.com/rdkit/rdkit/releases/tag/Release_2024_09_1
ENV PATH="/usr/local/pgsql/bin/:/usr/local/pgsql/:$PATH"
RUN case "${PG_VERSION}" in \
"v17") \
export RDKIT_VERSION=Release_2024_09_1 \
export RDKIT_CHECKSUM=034c00d6e9de323506834da03400761ed8c3721095114369d06805409747a60f \
;; \
"v14" | "v15" | "v16") \
export RDKIT_VERSION=Release_2023_03_3 \
export RDKIT_CHECKSUM=bdbf9a2e6988526bfeb8c56ce3cdfe2998d60ac289078e2215374288185e8c8d \
;; \
*) \
echo "unexpected PostgreSQL version" && exit 1 \
;; \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/rdkit/rdkit/archive/refs/tags/${RDKIT_VERSION}.tar.gz -O rdkit.tar.gz && \
echo "${RDKIT_CHECKSUM} rdkit.tar.gz" | sha256sum --check && \
wget https://github.com/rdkit/rdkit/archive/refs/tags/Release_2023_03_3.tar.gz -O rdkit.tar.gz && \
echo "bdbf9a2e6988526bfeb8c56ce3cdfe2998d60ac289078e2215374288185e8c8d rdkit.tar.gz" | sha256sum --check && \
mkdir rdkit-src && cd rdkit-src && tar xzf ../rdkit.tar.gz --strip-components=1 -C . && \
cmake \
-D RDK_BUILD_CAIRO_SUPPORT=OFF \
@@ -724,11 +678,12 @@ FROM build-deps AS pg-uuidv7-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# not version-specific
# last release v1.6.0 - Oct 9, 2024
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN wget https://github.com/fboulnois/pg_uuidv7/archive/refs/tags/v1.6.0.tar.gz -O pg_uuidv7.tar.gz && \
echo "0fa6c710929d003f6ce276a7de7a864e9d1667b2d78be3dc2c07f2409eb55867 pg_uuidv7.tar.gz" | sha256sum --check && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
wget https://github.com/fboulnois/pg_uuidv7/archive/refs/tags/v1.0.1.tar.gz -O pg_uuidv7.tar.gz && \
echo "0d0759ab01b7fb23851ecffb0bce27822e1868a4a5819bfd276101c716637a7a pg_uuidv7.tar.gz" | sha256sum --check && \
mkdir pg_uuidv7-src && cd pg_uuidv7-src && tar xzf ../pg_uuidv7.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \
@@ -799,8 +754,6 @@ RUN case "${PG_VERSION}" in \
FROM build-deps AS pg-embedding-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# This is our extension, support stopped in favor of pgvector
# TODO: deprecate it
ARG PG_VERSION
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN case "${PG_VERSION}" in \
@@ -827,8 +780,6 @@ FROM build-deps AS pg-anon-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# This is an experimental extension, never got to real production.
# !Do not remove! It can be present in shared_preload_libraries and compute will fail to start if library is not found.
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN case "${PG_VERSION}" in "v17") \
echo "postgresql_anonymizer does not yet support PG17" && exit 0;; \
@@ -978,8 +929,8 @@ ARG PG_VERSION
RUN case "${PG_VERSION}" in "v17") \
echo "pg_session_jwt does not yet have a release that supports pg17" && exit 0;; \
esac && \
wget https://github.com/neondatabase/pg_session_jwt/archive/1c79c014c4c225c8684dc24a88369e79b4dbe762.tar.gz -O pg_session_jwt.tar.gz && \
echo "bc04b25626a88580b6fed1b87f45ba0a7ca66dbac003a3ec378a1a21b1456d8b pg_session_jwt.tar.gz" | sha256sum --check && \
wget https://github.com/neondatabase/pg_session_jwt/archive/5aee2625af38213650e1a07ae038fdc427250ee4.tar.gz -O pg_session_jwt.tar.gz && \
echo "5d91b10bc1347d36cffc456cb87bec25047935d6503dc652ca046f04760828e7 pg_session_jwt.tar.gz" | sha256sum --check && \
mkdir pg_session_jwt-src && cd pg_session_jwt-src && tar xzf ../pg_session_jwt.tar.gz --strip-components=1 -C . && \
sed -i 's/pgrx = "=0.11.3"/pgrx = { version = "=0.11.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
cargo pgrx install --release
@@ -995,12 +946,13 @@ FROM build-deps AS wal2json-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# wal2json wal2json_2_6 supports v17
# last release wal2json_2_6 - Apr 25, 2024
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN wget https://github.com/eulerto/wal2json/archive/refs/tags/wal2json_2_6.tar.gz -O wal2json.tar.gz && \
echo "18b4bdec28c74a8fc98a11c72de38378a760327ef8e5e42e975b0029eb96ba0d wal2json.tar.gz" | sha256sum --check && \
mkdir wal2json-src && cd wal2json-src && tar xzf ../wal2json.tar.gz --strip-components=1 -C . && \
RUN case "${PG_VERSION}" in "v17") \
echo "We'll need to update wal2json to 2.6+ for pg17 support" && exit 0;; \
esac && \
wget https://github.com/eulerto/wal2json/archive/refs/tags/wal2json_2_5.tar.gz && \
echo "b516653575541cf221b99cf3f8be9b6821f6dbcfc125675c85f35090f824f00e wal2json_2_5.tar.gz" | sha256sum --check && \
mkdir wal2json-src && cd wal2json-src && tar xzf ../wal2json_2_5.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install
@@ -1014,11 +966,12 @@ FROM build-deps AS pg-ivm-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# pg_ivm v1.9 supports v17
# last release v1.9 - Jul 31
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN wget https://github.com/sraoss/pg_ivm/archive/refs/tags/v1.9.tar.gz -O pg_ivm.tar.gz && \
echo "59e15722939f274650abf637f315dd723c87073496ca77236b044cb205270d8b pg_ivm.tar.gz" | sha256sum --check && \
RUN case "${PG_VERSION}" in "v17") \
echo "We'll need to update pg_ivm to 1.9+ for pg17 support" && exit 0;; \
esac && \
wget https://github.com/sraoss/pg_ivm/archive/refs/tags/v1.7.tar.gz -O pg_ivm.tar.gz && \
echo "ebfde04f99203c7be4b0e873f91104090e2e83e5429c32ac242d00f334224d5e pg_ivm.tar.gz" | sha256sum --check && \
mkdir pg_ivm-src && cd pg_ivm-src && tar xzf ../pg_ivm.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \
@@ -1034,11 +987,12 @@ FROM build-deps AS pg-partman-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
# should support v17 https://github.com/pgpartman/pg_partman/discussions/693
# last release 5.1.0 Apr 2, 2024
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN wget https://github.com/pgpartman/pg_partman/archive/refs/tags/v5.1.0.tar.gz -O pg_partman.tar.gz && \
echo "3e3a27d7ff827295d5c55ef72f07a49062d6204b3cb0b9a048645d6db9f3cb9f pg_partman.tar.gz" | sha256sum --check && \
RUN case "${PG_VERSION}" in "v17") \
echo "pg_partman doesn't support PG17 yet" && exit 0;; \
esac && \
wget https://github.com/pgpartman/pg_partman/archive/refs/tags/v5.0.1.tar.gz -O pg_partman.tar.gz && \
echo "75b541733a9659a6c90dbd40fccb904a630a32880a6e3044d0c4c5f4c8a65525 pg_partman.tar.gz" | sha256sum --check && \
mkdir pg_partman-src && cd pg_partman-src && tar xzf ../pg_partman.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \
@@ -1221,13 +1175,12 @@ RUN rm /usr/local/pgsql/lib/lib*.a
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS sql_exporter_preprocessor
ARG PG_VERSION
USER nonroot
COPY --chown=nonroot compute compute
RUN make PG_VERSION="${PG_VERSION}" -C compute
RUN make -C compute
#########################################################################################
#

View File

@@ -6,15 +6,13 @@ jsonnet_files = $(wildcard \
all: neon_collector.yml neon_collector_autoscaling.yml sql_exporter.yml sql_exporter_autoscaling.yml
neon_collector.yml: $(jsonnet_files)
JSONNET_PATH=jsonnet:etc jsonnet \
JSONNET_PATH=etc jsonnet \
--output-file etc/$@ \
--ext-str pg_version=$(PG_VERSION) \
etc/neon_collector.jsonnet
neon_collector_autoscaling.yml: $(jsonnet_files)
JSONNET_PATH=jsonnet:etc jsonnet \
JSONNET_PATH=etc jsonnet \
--output-file etc/$@ \
--ext-str pg_version=$(PG_VERSION) \
etc/neon_collector_autoscaling.jsonnet
sql_exporter.yml: $(jsonnet_files)

View File

@@ -28,7 +28,7 @@ function(collector_file, application_name='sql_exporter') {
// Collectors (referenced by name) to execute on the target.
// Glob patterns are supported (see <https://pkg.go.dev/path/filepath#Match> for syntax).
collectors: [
'neon_collector',
'neon_collector_autoscaling',
],
},

View File

@@ -1 +0,0 @@
SELECT num_requested AS checkpoints_req FROM pg_stat_checkpointer;

View File

@@ -1,8 +1,3 @@
local neon = import 'neon.libsonnet';
local pg_stat_bgwriter = importstr 'sql_exporter/checkpoints_req.sql';
local pg_stat_checkpointer = importstr 'sql_exporter/checkpoints_req.17.sql';
{
metric_name: 'checkpoints_req',
type: 'gauge',
@@ -11,5 +6,5 @@ local pg_stat_checkpointer = importstr 'sql_exporter/checkpoints_req.17.sql';
values: [
'checkpoints_req',
],
query: if neon.PG_MAJORVERSION_NUM < 17 then pg_stat_bgwriter else pg_stat_checkpointer,
query: importstr 'sql_exporter/checkpoints_req.sql',
}
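
The deleted `.libsonnet` logic chose between two SQL files by Postgres major version, because PG 17 moved checkpoint counters out of `pg_stat_bgwriter` into the new `pg_stat_checkpointer` view. A Rust sketch of that selection; the 17+ query is the deleted `.17.sql` file above, while the pre-17 query is inferred from the `pg_stat_bgwriter` column name:

```rust
/// Pick the checkpoints_req query by server version, as the removed
/// `if neon.PG_MAJORVERSION_NUM < 17` expression did.
fn checkpoints_req_query(pg_major: u32) -> &'static str {
    if pg_major < 17 {
        // Before PG 17 the counter lives in pg_stat_bgwriter.
        "SELECT checkpoints_req FROM pg_stat_bgwriter;"
    } else {
        // PG 17 moved checkpoint stats to pg_stat_checkpointer.
        "SELECT num_requested AS checkpoints_req FROM pg_stat_checkpointer;"
    }
}
```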

View File

@@ -1 +0,0 @@
SELECT num_timed AS checkpoints_timed FROM pg_stat_checkpointer;

View File

@@ -1,8 +1,3 @@
local neon = import 'neon.libsonnet';
local pg_stat_bgwriter = importstr 'sql_exporter/checkpoints_req.sql';
local pg_stat_checkpointer = importstr 'sql_exporter/checkpoints_req.17.sql';
{
metric_name: 'checkpoints_timed',
type: 'gauge',
@@ -11,5 +6,5 @@ local pg_stat_checkpointer = importstr 'sql_exporter/checkpoints_req.17.sql';
values: [
'checkpoints_timed',
],
query: if neon.PG_MAJORVERSION_NUM < 17 then pg_stat_bgwriter else pg_stat_checkpointer,
query: importstr 'sql_exporter/checkpoints_timed.sql',
}

View File

@@ -1,16 +0,0 @@
local MIN_SUPPORTED_VERSION = 14;
local MAX_SUPPORTED_VERSION = 17;
local SUPPORTED_VERSIONS = std.range(MIN_SUPPORTED_VERSION, MAX_SUPPORTED_VERSION);
# If we receive the pg_version with a leading "v", ditch it.
local pg_version = std.strReplace(std.extVar('pg_version'), 'v', '');
local pg_version_num = std.parseInt(pg_version);
assert std.setMember(pg_version_num, SUPPORTED_VERSIONS) :
std.format('%s is an unsupported Postgres version: %s',
[pg_version, std.toString(SUPPORTED_VERSIONS)]);
{
PG_MAJORVERSION: pg_version,
PG_MAJORVERSION_NUM: pg_version_num,
}
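
The deleted `neon.libsonnet` helper normalizes the `pg_version` ext-var ("v17" or "17") into a number and asserts it is a supported release. An equivalent Rust sketch:

```rust
/// Strip an optional leading "v", parse the major version, and reject
/// anything outside the supported 14..=17 range, as the jsonnet assert did.
fn parse_pg_major(raw: &str) -> Result<u32, String> {
    let n: u32 = raw
        .trim_start_matches('v')
        .parse()
        .map_err(|e| format!("bad pg_version {raw:?}: {e}"))?;
    if (14..=17).contains(&n) {
        Ok(n)
    } else {
        Err(format!("{n} is an unsupported Postgres version (supported: 14-17)"))
    }
}
```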

View File

@@ -25,7 +25,6 @@ use tracing::{debug, error, info, instrument, warn};
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use compute_api::privilege::Privilege;
use compute_api::responses::{ComputeMetrics, ComputeStatus};
use compute_api::spec::{ComputeFeature, ComputeMode, ComputeSpec};
use utils::measured_stream::MeasuredReader;
@@ -1368,96 +1367,6 @@ LIMIT 100",
download_size
}
pub async fn install_extension(
&self,
ext_name: &str,
db_name: &str,
ext_version: &str,
) -> Result<String> {
use tokio_postgres::config::Config;
use tokio_postgres::NoTls;
let mut conf = Config::from_str(self.connstr.as_str()).unwrap();
conf.dbname(db_name);
let (db_client, conn) = conf
.connect(NoTls)
.await
.context("Failed to connect to the database")?;
tokio::spawn(conn);
let version_query = "SELECT extversion FROM pg_extension WHERE extname = $1";
let version: Option<String> = db_client
.query_opt(version_query, &[&ext_name])
.await
.with_context(|| format!("Failed to execute query: {}", version_query))?
.map(|row| row.get(0));
// sanitize the inputs as postgres idents.
let ext_name: String = ext_name.to_string().pg_quote();
let ext_version: String = ext_version.to_string().pg_quote();
if let Some(installed_version) = version {
if installed_version == ext_version {
return Ok(installed_version);
}
let query = format!("ALTER EXTENSION {ext_name} UPDATE TO {ext_version}");
db_client
.simple_query(&query)
.await
.context(format!("Failed to execute query: {}", query))?;
} else {
let query =
format!("CREATE EXTENSION IF NOT EXISTS {ext_name} WITH VERSION {ext_version}");
db_client
.simple_query(&query)
.await
.context(format!("Failed to execute query: {}", query))?;
}
Ok(ext_version.to_string())
}
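
The `install_extension` body above first looks up the installed version, then either no-ops, upgrades, or creates. The decision, isolated as a Rust sketch (names are illustrative; `name_quoted` is assumed to be already quoted as an identifier):

```rust
/// Decide which DDL (if any) to run: None when the requested version is
/// already installed, ALTER to change versions, CREATE otherwise.
fn extension_ddl(name_quoted: &str, want: &str, installed: Option<&str>) -> Option<String> {
    match installed {
        Some(v) if v == want => None,
        Some(_) => Some(format!("ALTER EXTENSION {name_quoted} UPDATE TO {want}")),
        None => Some(format!(
            "CREATE EXTENSION IF NOT EXISTS {name_quoted} WITH VERSION {want}"
        )),
    }
}
```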
pub async fn set_role_grants(
&self,
db_name: &str,
schema_name: &str,
privileges: &[Privilege],
role_name: &str,
) -> Result<()> {
use tokio_postgres::config::Config;
use tokio_postgres::NoTls;
let mut conf = Config::from_str(self.connstr.as_str()).unwrap();
conf.dbname(db_name);
let (db_client, conn) = conf
.connect(NoTls)
.await
.context("Failed to connect to the database")?;
tokio::spawn(conn);
let query = format!(
"GRANT {} ON SCHEMA {} TO {}",
privileges
.iter()
// should not be quoted as it's part of the command.
// is already sanitized so it's ok
.map(|p| p.as_str())
.collect::<Vec<&'static str>>()
.join(", "),
// quote the schema and role name as identifiers to sanitize them.
schema_name.to_string().pg_quote(),
role_name.to_string().pg_quote(),
);
db_client
.simple_query(&query)
.await
.context(format!("Failed to execute query: {}", query))?;
Ok(())
}
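
`set_role_grants` splices names directly into `GRANT` because that statement takes no bind parameters: privileges come from a closed enum (safe to insert verbatim), while schema and role are quoted as identifiers via `pg_quote`. A minimal sketch of that sanitization, assuming `pg_quote` doubles embedded quotes:

```rust
/// Quote a string as a Postgres identifier: wrap in double quotes and
/// double any embedded quotes so input cannot break out of the statement.
fn pg_quote_ident(raw: &str) -> String {
    format!("\"{}\"", raw.replace('"', "\"\""))
}

/// Build the GRANT statement the way the removed code does.
fn grant_query(privileges: &[&'static str], schema: &str, role: &str) -> String {
    format!(
        "GRANT {} ON SCHEMA {} TO {}",
        privileges.join(", "), // enum-derived, not user-controlled
        pg_quote_ident(schema),
        pg_quote_ident(role),
    )
}

// grant_query(&["SELECT", "INSERT"], "public", "neon")
// => GRANT SELECT, INSERT ON SCHEMA "public" TO "neon"
```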
#[tokio::main]
pub async fn prepare_preload_libraries(
&self,

View File

@@ -9,13 +9,8 @@ use crate::catalog::SchemaDumpError;
use crate::catalog::{get_database_schema, get_dbs_and_roles};
use crate::compute::forward_termination_signal;
use crate::compute::{ComputeNode, ComputeState, ParsedSpec};
use compute_api::requests::{
ExtensionInstallRequest, {ConfigurationRequest, SetRoleGrantsRequest},
};
use compute_api::responses::{
ComputeStatus, ComputeStatusResponse, ExtensionInstallResult, GenericAPIError,
SetRoleGrantsResponse,
};
use compute_api::requests::ConfigurationRequest;
use compute_api::responses::{ComputeStatus, ComputeStatusResponse, GenericAPIError};
use anyhow::Result;
use hyper::header::CONTENT_TYPE;
@@ -103,38 +98,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
(&Method::POST, "/extensions") => {
info!("serving /extensions POST request");
let status = compute.get_status();
if status != ComputeStatus::Running {
let msg = format!(
"invalid compute status for extensions request: {:?}",
status
);
error!(msg);
return Response::new(Body::from(msg));
}
let request = hyper::body::to_bytes(req.into_body()).await.unwrap();
let request = serde_json::from_slice::<ExtensionInstallRequest>(&request).unwrap();
let res = compute
.install_extension(&request.extension, &request.database, &request.version)
.await;
match res {
Ok(version) => render_json(Body::from(
serde_json::to_string(&ExtensionInstallResult {
extension: request.extension,
version,
})
.unwrap(),
)),
Err(e) => {
error!("install_extension failed: {}", e);
render_json_error(&e.to_string(), StatusCode::INTERNAL_SERVER_ERROR)
}
}
}
(&Method::GET, "/info") => {
let num_cpus = num_cpus::get_physical();
info!("serving /info GET request. num_cpus: {}", num_cpus);
@@ -202,46 +165,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
(&Method::POST, "/grants") => {
info!("serving /grants POST request");
let status = compute.get_status();
if status != ComputeStatus::Running {
let msg = format!(
"invalid compute status for set_role_grants request: {:?}",
status
);
error!(msg);
return Response::new(Body::from(msg));
}
let request = hyper::body::to_bytes(req.into_body()).await.unwrap();
let request = serde_json::from_slice::<SetRoleGrantsRequest>(&request).unwrap();
let res = compute
.set_role_grants(
&request.database,
&request.schema,
&request.privileges,
&request.role,
)
.await;
match res {
Ok(()) => render_json(Body::from(
serde_json::to_string(&SetRoleGrantsResponse {
database: request.database,
schema: request.schema,
role: request.role,
privileges: request.privileges,
})
.unwrap(),
)),
Err(e) => {
error!("set_role_grants failed: {}", e);
Response::new(Body::from(e.to_string()))
}
}
}
// get the list of installed extensions
// currently only used in python tests
// TODO: call it from cplane

View File

@@ -10,7 +10,7 @@ paths:
/status:
get:
tags:
- Info
- Info
summary: Get compute node internal status.
description: ""
operationId: getComputeStatus
@@ -25,7 +25,7 @@ paths:
/metrics.json:
get:
tags:
- Info
- Info
summary: Get compute node startup metrics in JSON format.
description: ""
operationId: getComputeMetricsJSON
@@ -40,7 +40,7 @@ paths:
/insights:
get:
tags:
- Info
- Info
summary: Get current compute insights in JSON format.
description: |
Note, that this doesn't include any historical data.
@@ -56,7 +56,7 @@ paths:
/installed_extensions:
get:
tags:
- Info
- Info
summary: Get installed extensions.
description: ""
operationId: getInstalledExtensions
@@ -70,7 +70,7 @@ paths:
/info:
get:
tags:
- Info
- Info
summary: Get info about the compute pod / VM.
description: ""
operationId: getInfo
@@ -127,38 +127,10 @@ paths:
schema:
$ref: "#/components/schemas/GenericError"
/grants:
post:
tags:
- Grants
summary: Apply grants to the database.
description: ""
operationId: setRoleGrants
requestBody:
description: Grants request.
required: true
content:
application/json:
schema:
$ref: SetRoleGrantsRequest
responses:
200:
description: Grants applied.
content:
application/json:
schema:
$ref: "#/components/schemas/SetRoleGrantsResponse"
500:
description: Error occurred during grants application.
content:
application/json:
schema:
$ref: "#/components/schemas/GenericError"
/check_writability:
post:
tags:
- Check
- Check
summary: Check that we can write new data on this compute.
description: ""
operationId: checkComputeWritability
@@ -172,38 +144,10 @@ paths:
description: Error text or 'true' if check passed.
example: "true"
/extensions:
post:
tags:
- Extensions
summary: Install extension if possible.
description: ""
operationId: installExtension
requestBody:
description: Extension name and database to install it to.
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/ExtensionInstallRequest"
responses:
200:
description: Result from extension installation
content:
application/json:
schema:
$ref: "#/components/schemas/ExtensionInstallResult"
500:
description: Error during extension installation.
content:
application/json:
schema:
$ref: "#/components/schemas/GenericError"
/configure:
post:
tags:
- Configure
- Configure
summary: Perform compute node configuration.
description: |
This is a blocking API endpoint, i.e. it blocks waiting until
@@ -257,7 +201,7 @@ paths:
/extension_server:
post:
tags:
- Extension
- Extension
summary: Download extension from S3 to local folder.
description: ""
operationId: downloadExtension
@@ -286,7 +230,7 @@ paths:
/terminate:
post:
tags:
- Terminate
- Terminate
summary: Terminate Postgres and wait for it to exit
description: ""
operationId: terminate
@@ -425,7 +369,7 @@ components:
moment, when spec was received.
example: "2022-10-12T07:20:50.52Z"
status:
$ref: "#/components/schemas/ComputeStatus"
$ref: '#/components/schemas/ComputeStatus'
last_active:
type: string
description: |
@@ -465,38 +409,6 @@ components:
- configuration
example: running
ExtensionInstallRequest:
type: object
required:
- extension
- database
- version
properties:
extension:
type: string
description: Extension name.
example: "pg_session_jwt"
version:
type: string
description: Version of the extension.
example: "1.0.0"
database:
type: string
description: Database name.
example: "neondb"
ExtensionInstallResult:
type: object
properties:
extension:
description: Name of the extension.
type: string
example: "pg_session_jwt"
version:
description: Version of the extension.
type: string
example: "1.0.0"
InstalledExtensions:
type: object
properties:
@@ -515,60 +427,6 @@ components:
n_databases:
type: integer
SetRoleGrantsRequest:
type: object
required:
- database
- schema
- privileges
- role
properties:
database:
type: string
description: Database name.
example: "neondb"
schema:
type: string
description: Schema name.
example: "public"
privileges:
type: array
items:
type: string
description: List of privileges to set.
example: ["SELECT", "INSERT"]
role:
type: string
description: Role name.
example: "neon"
SetRoleGrantsResponse:
type: object
required:
- database
- schema
- privileges
- role
properties:
database:
type: string
description: Database name.
example: "neondb"
schema:
type: string
description: Schema name.
example: "public"
privileges:
type: array
items:
type: string
description: List of privileges set.
example: ["SELECT", "INSERT"]
role:
type: string
description: Role name.
example: "neon"
#
# Errors
#

View File

@@ -33,7 +33,6 @@ fn list_dbs(client: &mut Client) -> Result<Vec<String>> {
}
/// Connect to every database (see list_dbs above) and get the list of installed extensions.
///
/// The same extension can be installed in multiple databases at different versions;
/// we only keep the highest and lowest version across all databases.
pub async fn get_installed_extensions(connstr: Url) -> Result<InstalledExtensions> {

View File

@@ -1,6 +1,5 @@
#![deny(unsafe_code)]
#![deny(clippy::undocumented_unsafe_blocks)]
pub mod privilege;
pub mod requests;
pub mod responses;
pub mod spec;

View File

@@ -1,35 +0,0 @@
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum Privilege {
Select,
Insert,
Update,
Delete,
Truncate,
References,
Trigger,
Usage,
Create,
Connect,
Temporary,
Execute,
}
impl Privilege {
pub fn as_str(&self) -> &'static str {
match self {
Privilege::Select => "SELECT",
Privilege::Insert => "INSERT",
Privilege::Update => "UPDATE",
Privilege::Delete => "DELETE",
Privilege::Truncate => "TRUNCATE",
Privilege::References => "REFERENCES",
Privilege::Trigger => "TRIGGER",
Privilege::Usage => "USAGE",
Privilege::Create => "CREATE",
Privilege::Connect => "CONNECT",
Privilege::Temporary => "TEMPORARY",
Privilege::Execute => "EXECUTE",
}
}
}

View File

@@ -1,6 +1,6 @@
//! Structs representing the JSON formats used in the compute_ctl's HTTP API.
use crate::{privilege::Privilege, spec::ComputeSpec};
use crate::spec::ComputeSpec;
use serde::Deserialize;
/// Request of the /configure API
@@ -12,18 +12,3 @@ use serde::Deserialize;
pub struct ConfigurationRequest {
pub spec: ComputeSpec,
}
#[derive(Deserialize, Debug)]
pub struct ExtensionInstallRequest {
pub extension: String,
pub database: String,
pub version: String,
}
#[derive(Deserialize, Debug)]
pub struct SetRoleGrantsRequest {
pub database: String,
pub schema: String,
pub privileges: Vec<Privilege>,
pub role: String,
}

View File

@@ -6,10 +6,7 @@ use std::fmt::Display;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize, Serializer};
use crate::{
privilege::Privilege,
spec::{ComputeSpec, Database, Role},
};
use crate::spec::{ComputeSpec, Database, Role};
#[derive(Serialize, Debug, Deserialize)]
pub struct GenericAPIError {
@@ -171,17 +168,3 @@ pub struct InstalledExtension {
pub struct InstalledExtensions {
pub extensions: Vec<InstalledExtension>,
}
#[derive(Clone, Debug, Default, Serialize)]
pub struct ExtensionInstallResult {
pub extension: String,
pub version: String,
}
#[derive(Clone, Debug, Default, Serialize)]
pub struct SetRoleGrantsResponse {
pub database: String,
pub schema: String,
pub privileges: Vec<Privilege>,
pub role: String,
}

View File

@@ -16,7 +16,7 @@ aws-sdk-s3.workspace = true
bytes.workspace = true
camino = { workspace = true, features = ["serde1"] }
humantime-serde.workspace = true
hyper = { workspace = true, features = ["client"] }
hyper0 = { workspace = true, features = ["stream"] }
futures.workspace = true
serde.workspace = true
serde_json.workspace = true
@@ -36,7 +36,6 @@ azure_storage.workspace = true
azure_storage_blobs.workspace = true
futures-util.workspace = true
http-types.workspace = true
http-body-util.workspace = true
itertools.workspace = true
sync_wrapper = { workspace = true, features = ["futures"] }

View File

@@ -28,15 +28,13 @@ use aws_sdk_s3::{
Client,
};
use aws_smithy_async::rt::sleep::TokioSleep;
use http_body_util::StreamBody;
use http_types::StatusCode;
use aws_smithy_types::{body::SdkBody, DateTime};
use aws_smithy_types::{byte_stream::ByteStream, date_time::ConversionError};
use bytes::Bytes;
use futures::stream::Stream;
use futures_util::StreamExt;
use hyper::body::Frame;
use hyper0::Body;
use scopeguard::ScopeGuard;
use tokio_util::sync::CancellationToken;
use utils::backoff;
@@ -712,8 +710,8 @@ impl RemoteStorage for S3Bucket {
let started_at = start_measuring_requests(kind);
let body = StreamBody::new(from.map(|x| x.map(Frame::data)));
let bytes_stream = ByteStream::new(SdkBody::from_body_1_x(body));
let body = Body::wrap_stream(from);
let bytes_stream = ByteStream::new(SdkBody::from_body_0_4(body));
let upload = self
.client
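
This hunk swaps the hyper 1.x body plumbing (`StreamBody` + `Frame::data` + `SdkBody::from_body_1_x`) for the hyper 0.14 path. A sketch of the 0.14 conversion in isolation, assuming the `hyper0` rename from the Cargo.toml hunk above and `aws-smithy-types` built with its http-body 0.4 support:

```rust
use bytes::Bytes;
use futures::stream::Stream;

/// Wrap a byte stream as an SDK body via hyper 0.14's Body::wrap_stream,
/// matching the `from_body_0_4` call in the hunk above.
fn to_sdk_body<S>(from: S) -> aws_smithy_types::body::SdkBody
where
    S: Stream<Item = Result<Bytes, std::io::Error>> + Send + 'static,
{
    let body = hyper0::Body::wrap_stream(from);
    aws_smithy_types::body::SdkBody::from_body_0_4(body)
}
```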

View File

@@ -720,12 +720,7 @@ async fn timeline_archival_config_handler(
tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?;
tenant
.apply_timeline_archival_config(
timeline_id,
request_data.state,
state.broker_client.clone(),
ctx,
)
.apply_timeline_archival_config(timeline_id, request_data.state, ctx)
.await?;
Ok::<_, ApiError>(())
}

View File

@@ -67,7 +67,7 @@ use self::metadata::TimelineMetadata;
use self::mgr::GetActiveTenantError;
use self::mgr::GetTenantError;
use self::remote_timeline_client::upload::upload_index_part;
use self::remote_timeline_client::{RemoteTimelineClient, WaitCompletionError};
use self::remote_timeline_client::RemoteTimelineClient;
use self::timeline::uninit::TimelineCreateGuard;
use self::timeline::uninit::TimelineExclusionError;
use self::timeline::uninit::UninitializedTimeline;
@@ -632,7 +632,7 @@ pub enum TimelineArchivalError {
AlreadyInProgress,
#[error(transparent)]
Other(anyhow::Error),
Other(#[from] anyhow::Error),
}
impl Debug for TimelineArchivalError {
@@ -1554,7 +1554,6 @@ impl Tenant {
async fn unoffload_timeline(
self: &Arc<Self>,
timeline_id: TimelineId,
broker_client: storage_broker::BrokerClientChannel,
ctx: RequestContext,
) -> Result<Arc<Timeline>, TimelineArchivalError> {
info!("unoffloading timeline");
@@ -1603,40 +1602,27 @@ impl Tenant {
"failed to load remote timeline {} for tenant {}",
timeline_id, self.tenant_shard_id
)
})
.map_err(TimelineArchivalError::Other)?;
})?;
let timelines = self.timelines.lock().unwrap();
let Some(timeline) = timelines.get(&timeline_id) else {
if let Some(timeline) = timelines.get(&timeline_id) {
let mut offloaded_timelines = self.timelines_offloaded.lock().unwrap();
if offloaded_timelines.remove(&timeline_id).is_none() {
warn!("timeline already removed from offloaded timelines");
}
info!("timeline unoffloading complete");
Ok(Arc::clone(timeline))
} else {
warn!("timeline not available directly after attach");
return Err(TimelineArchivalError::Other(anyhow::anyhow!(
Err(TimelineArchivalError::Other(anyhow::anyhow!(
"timeline not available directly after attach"
)));
};
let mut offloaded_timelines = self.timelines_offloaded.lock().unwrap();
if offloaded_timelines.remove(&timeline_id).is_none() {
warn!("timeline already removed from offloaded timelines");
)))
}
// Activate the timeline (if it makes sense)
if !(timeline.is_broken() || timeline.is_stopping()) {
let background_jobs_can_start = None;
timeline.activate(
self.clone(),
broker_client.clone(),
background_jobs_can_start,
&ctx,
);
}
info!("timeline unoffloading complete");
Ok(Arc::clone(timeline))
}
pub(crate) async fn apply_timeline_archival_config(
self: &Arc<Self>,
timeline_id: TimelineId,
new_state: TimelineArchivalState,
broker_client: storage_broker::BrokerClientChannel,
ctx: RequestContext,
) -> Result<(), TimelineArchivalError> {
info!("setting timeline archival config");
@@ -1677,29 +1663,18 @@ impl Tenant {
Some(Arc::clone(timeline))
};
// Second part: unoffload timeline (if needed)
// Second part: unarchive timeline (if needed)
let timeline = if let Some(timeline) = timeline_or_unarchive_offloaded {
timeline
} else {
// Turn offloaded timeline into a non-offloaded one
self.unoffload_timeline(timeline_id, broker_client, ctx)
.await?
self.unoffload_timeline(timeline_id, ctx).await?
};
// Third part: upload new timeline archival state and block until it is present in S3
let upload_needed = match timeline
let upload_needed = timeline
.remote_client
.schedule_index_upload_for_timeline_archival_state(new_state)
{
Ok(upload_needed) => upload_needed,
Err(e) => {
if timeline.cancel.is_cancelled() {
return Err(TimelineArchivalError::Cancelled);
} else {
return Err(TimelineArchivalError::Other(e));
}
}
};
.schedule_index_upload_for_timeline_archival_state(new_state)?;
if upload_needed {
info!("Uploading new state");
@@ -1710,14 +1685,7 @@ impl Tenant {
tracing::warn!("reached timeout for waiting on upload queue");
return Err(TimelineArchivalError::Timeout);
};
v.map_err(|e| match e {
WaitCompletionError::NotInitialized(e) => {
TimelineArchivalError::Other(anyhow::anyhow!(e))
}
WaitCompletionError::UploadQueueShutDownOrStopped => {
TimelineArchivalError::Cancelled
}
})?;
v.map_err(|e| TimelineArchivalError::Other(anyhow::anyhow!(e)))?;
}
Ok(())
}
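
The `TimelineArchivalError` hunk above toggles `#[from]` on the `Other` variant. With `#[from]`, thiserror derives `From<anyhow::Error>`, so `?` converts automatically and the explicit `.map_err(TimelineArchivalError::Other)` calls at the call sites disappear. A minimal sketch with hypothetical names:

```rust
use thiserror::Error;

#[derive(Debug, Error)]
enum ArchivalError {
    #[error("timed out")]
    Timeout,
    // #[from] generates `impl From<anyhow::Error> for ArchivalError`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

fn load() -> Result<(), ArchivalError> {
    let r: Result<(), anyhow::Error> =
        Err(anyhow::anyhow!("failed to load remote timeline"));
    r?; // converts via the derived From; no .map_err(ArchivalError::Other)
    Ok(())
}
```

The trade-off is less control per call site, e.g. the removed match that turned upload-queue shutdown into `Cancelled` rather than `Other`.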
@@ -3368,7 +3336,7 @@ impl Tenant {
/// Populate all Timelines' `GcInfo` with information about their children. We do not set the
/// PITR cutoffs here, because that requires I/O: this is done later, before GC, by [`Self::refresh_gc_info_internal`]
///
/// Subsequently, parent-child relationships are updated incrementally inside [`Timeline::new`] and [`Timeline::drop`].
/// Subsequently, parent-child relationships are updated incrementally during timeline creation/deletion.
fn initialize_gc_info(
&self,
timelines: &std::sync::MutexGuard<HashMap<TimelineId, Arc<Timeline>>>,

View File

@@ -3092,6 +3092,7 @@ impl Timeline {
}
impl Timeline {
#[allow(unknown_lints)] // doc_lazy_continuation is still a new lint
#[allow(clippy::doc_lazy_continuation)]
/// Get the data needed to reconstruct all keys in the provided keyspace
///

View File

@@ -617,34 +617,31 @@ lfc_evict(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
/* remove the page from the cache */
entry->bitmap[chunk_offs >> 5] &= ~(1 << (chunk_offs & (32 - 1)));
if (entry->access_count == 0)
/*
* If the chunk has no live entries, we can position the chunk to be
* recycled first.
*/
if (entry->bitmap[chunk_offs >> 5] == 0)
{
/*
* If the chunk has no live entries, we can position the chunk to be
* recycled first.
*/
if (entry->bitmap[chunk_offs >> 5] == 0)
bool has_remaining_pages = false;
for (int i = 0; i < CHUNK_BITMAP_SIZE; i++)
{
bool has_remaining_pages = false;
for (int i = 0; i < CHUNK_BITMAP_SIZE; i++)
if (entry->bitmap[i] != 0)
{
if (entry->bitmap[i] != 0)
{
has_remaining_pages = true;
break;
}
has_remaining_pages = true;
break;
}
}
/*
* Put the entry at the position that is first to be reclaimed when we
* have no cached pages remaining in the chunk
*/
if (!has_remaining_pages)
{
dlist_delete(&entry->list_node);
dlist_push_head(&lfc_ctl->lru, &entry->list_node);
}
/*
* Put the entry at the position that is first to be reclaimed when we
* have no cached pages remaining in the chunk
*/
if (!has_remaining_pages)
{
dlist_delete(&entry->list_node);
dlist_push_head(&lfc_ctl->lru, &entry->list_node);
}
}
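
The restructured eviction path first clears the page's bit, and only when that bitmap word drops to zero scans the remaining words; if no bits survive, the chunk is moved to the head of the LRU so it is reclaimed first. The emptiness check as a Rust sketch:

```rust
/// After clearing one page bit: rescan only when the touched word went to
/// zero, and report whether the chunk now holds no cached pages at all.
fn chunk_now_empty(bitmap: &[u32], cleared_word: usize) -> bool {
    bitmap[cleared_word] == 0 && bitmap.iter().all(|&w| w == 0)
}
```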

View File

@@ -1,15 +1,16 @@
use super::{ComputeCredentials, ComputeUserInfo};
use crate::{
auth::{self, backend::ComputeCredentialKeys, AuthFlow},
compute,
config::AuthenticationConfig,
context::RequestMonitoring,
control_plane::AuthSecret,
sasl,
stream::{PqStream, Stream},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
use super::{ComputeCredentials, ComputeUserInfo};
use crate::auth::backend::ComputeCredentialKeys;
use crate::auth::{self, AuthFlow};
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::AuthSecret;
use crate::stream::{PqStream, Stream};
use crate::{compute, sasl};
pub(super) async fn authenticate(
ctx: &RequestMonitoring,
creds: ComputeUserInfo,

View File

@@ -1,3 +1,15 @@
use crate::{
auth,
cache::Cached,
compute,
config::AuthenticationConfig,
context::RequestMonitoring,
control_plane::{self, provider::NodeInfo, CachedNodeInfo},
error::{ReportableError, UserFacingError},
proxy::connect_compute::ComputeConnectBackend,
stream::PqStream,
waiters,
};
use async_trait::async_trait;
use pq_proto::BeMessage as Be;
use thiserror::Error;
@@ -6,15 +18,6 @@ use tokio_postgres::config::SslMode;
use tracing::{info, info_span};
use super::ComputeCredentialKeys;
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::provider::NodeInfo;
use crate::control_plane::{self, CachedNodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::stream::PqStream;
use crate::{auth, compute, waiters};
#[derive(Debug, Error)]
pub(crate) enum WebAuthError {

View File

@@ -1,15 +1,16 @@
use super::{ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint};
use crate::{
auth::{self, AuthFlow},
config::AuthenticationConfig,
context::RequestMonitoring,
control_plane::AuthSecret,
intern::EndpointIdInt,
sasl,
stream::{self, Stream},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
use super::{ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint};
use crate::auth::{self, AuthFlow};
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::sasl;
use crate::stream::{self, Stream};
/// Compared to [SCRAM](crate::scram), cleartext password auth saves
/// one round trip and *expensive* computations (>= 4096 HMAC iterations).
/// These properties are beneficial for serverless JS workers, so we

View File

@@ -1,22 +1,22 @@
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use std::{
future::Future,
sync::Arc,
time::{Duration, SystemTime},
};
use arc_swap::ArcSwapOption;
use dashmap::DashMap;
use jose_jwk::crypto::KeyInfo;
use serde::de::Visitor;
use serde::{Deserialize, Deserializer};
use serde::{de::Visitor, Deserialize, Deserializer};
use signature::Verifier;
use thiserror::Error;
use tokio::time::Instant;
use crate::auth::backend::ComputeCredentialKeys;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::GetEndpointJwksError;
use crate::http::parse_json_body_with_limit;
use crate::intern::RoleNameInt;
use crate::{EndpointId, RoleName};
use crate::{
auth::backend::ComputeCredentialKeys, context::RequestMonitoring,
control_plane::errors::GetEndpointJwksError, http::parse_json_body_with_limit,
intern::RoleNameInt, EndpointId, RoleName,
};
// TODO(conrad): make these configurable.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
@@ -381,8 +381,10 @@ fn verify_rsa_signature(
alg: &jose_jwa::Algorithm,
) -> Result<(), JwtError> {
use jose_jwa::{Algorithm, Signing};
use rsa::pkcs1v15::{Signature, VerifyingKey};
use rsa::RsaPublicKey;
use rsa::{
pkcs1v15::{Signature, VerifyingKey},
RsaPublicKey,
};
let key = RsaPublicKey::try_from(key).map_err(JwtError::InvalidRsaKey)?;
@@ -653,9 +655,11 @@ impl From<&jose_jwk::Key> for KeyType {
#[cfg(test)]
mod tests {
use std::future::IntoFuture;
use std::net::SocketAddr;
use std::time::SystemTime;
use crate::RoleName;
use super::*;
use std::{future::IntoFuture, net::SocketAddr, time::SystemTime};
use base64::URL_SAFE_NO_PAD;
use bytes::Bytes;
@@ -668,9 +672,6 @@ mod tests {
use signature::Signer;
use tokio::net::TcpListener;
use super::*;
use crate::RoleName;
fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
let sk = p256::SecretKey::random(&mut OsRng);
let pk = sk.public_key().into();

View File

@@ -1,32 +1,28 @@
use std::net::SocketAddr;
use arc_swap::ArcSwapOption;
use tokio::sync::Semaphore;
use crate::{
auth::backend::jwt::FetchAuthRulesError,
compute::ConnCfg,
context::RequestMonitoring,
control_plane::{
messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo},
NodeInfo,
},
intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag},
EndpointId,
};
use super::jwt::{AuthRule, FetchAuthRules};
use crate::auth::backend::jwt::FetchAuthRulesError;
use crate::compute::ConnCfg;
use crate::compute_ctl::ComputeCtlApi;
use crate::context::RequestMonitoring;
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo};
use crate::control_plane::NodeInfo;
use crate::intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag};
use crate::url::ApiUrl;
use crate::{http, EndpointId};
pub struct LocalBackend {
pub(crate) initialize: Semaphore,
pub(crate) compute_ctl: ComputeCtlApi,
pub(crate) node_info: NodeInfo,
}
impl LocalBackend {
pub fn new(postgres_addr: SocketAddr, compute_ctl: ApiUrl) -> Self {
pub fn new(postgres_addr: SocketAddr) -> Self {
LocalBackend {
initialize: Semaphore::new(1),
compute_ctl: ComputeCtlApi {
api: http::Endpoint::new(compute_ctl, http::new_client()),
},
node_info: NodeInfo {
config: {
let mut cfg = ConnCfg::new();

View File

@@ -17,22 +17,29 @@ use tokio_postgres::config::AuthKeys;
use tracing::{info, warn};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{self, validate_password_and_exchange, AuthError, ComputeUserInfoMaybeEndpoint};
use crate::auth::{validate_password_and_exchange, AuthError};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::provider::{
CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneBackend,
};
use crate::control_plane::{self, Api, AuthSecret};
use crate::control_plane::provider::{CachedRoleSecret, ControlPlaneBackend};
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::NeonOptions;
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
use crate::stream::Stream;
use crate::{scram, stream, EndpointCacheKey, EndpointId, RoleName};
use crate::{
auth::{self, ComputeUserInfoMaybeEndpoint},
config::AuthenticationConfig,
control_plane::{
self,
provider::{CachedAllowedIps, CachedNodeInfo},
Api,
},
stream,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
@@ -493,32 +500,34 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
#[cfg(test)]
mod tests {
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use std::{net::IpAddr, sync::Arc, time::Duration};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use once_cell::sync::Lazy;
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
use postgres_protocol::message::backend::Message as PgMessage;
use postgres_protocol::message::frontend;
use postgres_protocol::{
authentication::sasl::{ChannelBinding, ScramSha256},
message::{backend::Message as PgMessage, frontend},
};
use provider::AuthSecret;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use super::jwt::JwkCache;
use super::{auth_quirks, AuthRateLimiter};
use crate::auth::backend::MaskedIp;
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
use crate::config::AuthenticationConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::provider::{self, CachedAllowedIps, CachedRoleSecret};
use crate::control_plane::{self, CachedNodeInfo};
use crate::proxy::NeonOptions;
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::scram::ServerSecret;
use crate::stream::{PqStream, Stream};
use crate::{
auth::{backend::MaskedIp, ComputeUserInfoMaybeEndpoint, IpPattern},
config::AuthenticationConfig,
context::RequestMonitoring,
control_plane::{
self,
provider::{self, CachedAllowedIps, CachedRoleSecret},
CachedNodeInfo,
},
proxy::NeonOptions,
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
scram::{threadpool::ThreadPool, ServerSecret},
stream::{PqStream, Stream},
};
use super::{auth_quirks, jwt::JwkCache, AuthRateLimiter};
struct Auth {
ips: Vec<IpPattern>,

View File

@@ -1,22 +1,20 @@
//! User credentials used in authentication.
use std::collections::HashSet;
use std::net::IpAddr;
use std::str::FromStr;
use crate::{
auth::password_hack::parse_endpoint_param,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::{Metrics, SniKind},
proxy::NeonOptions,
serverless::SERVERLESS_DRIVER_SNI,
EndpointId, RoleName,
};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
use std::{collections::HashSet, net::IpAddr, str::FromStr};
use thiserror::Error;
use tracing::{info, warn};
use crate::auth::password_hack::parse_endpoint_param;
use crate::context::RequestMonitoring;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, SniKind};
use crate::proxy::NeonOptions;
use crate::serverless::SERVERLESS_DRIVER_SNI;
use crate::{EndpointId, RoleName};
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub(crate) enum ComputeUserInfoParseError {
#[error("Parameter '{0}' is missing in startup packet.")]
@@ -251,11 +249,10 @@ fn project_name_valid(name: &str) -> bool {
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use ComputeUserInfoParseError::*;
use super::*;
#[test]
fn parse_bare_minimum() -> anyhow::Result<()> {
// According to postgresql, only `user` should be required.

View File

@@ -1,24 +1,21 @@
//! Main authentication flow.
use std::io;
use std::sync::Arc;
use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
use crate::{
config::TlsServerEndPoint,
context::RequestMonitoring,
control_plane::AuthSecret,
intern::EndpointIdInt,
sasl,
scram::{self, threadpool::ThreadPool},
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::{io, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use super::backend::ComputeCredentialKeys;
use super::{AuthError, PasswordHackPayload};
use crate::config::TlsServerEndPoint;
use crate::context::RequestMonitoring;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::sasl;
use crate::scram::threadpool::ThreadPool;
use crate::scram::{self};
use crate::stream::{PqStream, Stream};
/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {
/// Any authentication selector should provide initial backend message
@@ -117,14 +114,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
.ok_or(AuthError::MalformedPassword("missing terminator"))?;
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
let payload = PasswordHackPayload::parse(password)
// If we ended up here and the payload is malformed, it means that
// the user neither enabled SNI nor resorted to any other method
// for passing the project name we rely on. We should show them
// the most helpful error message and point to the documentation.
.ok_or(AuthError::MissingEndpointName)?;
.ok_or(AuthErrorImpl::MissingEndpointName)?;
Ok(payload)
}
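
Postgres password messages arrive NUL-terminated, so each flow strips the trailing zero byte and treats its absence as a malformed message. That step in isolation, as a Rust sketch:

```rust
/// Strip the trailing NUL from a frontend password message, failing the
/// same way the flows above do when the terminator is missing.
fn strip_terminator(msg: &[u8]) -> Result<&[u8], &'static str> {
    msg.strip_suffix(&[0]).ok_or("malformed password: missing terminator")
}
```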
@@ -136,7 +133,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
.ok_or(AuthError::MalformedPassword("missing terminator"))?;
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
let outcome = validate_password_and_exchange(
&self.state.pool,
@@ -166,7 +163,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
.ok_or(AuthError::MalformedPassword("bad sasl message"))?;
.ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
// Currently, the only supported SASL method is SCRAM.
if !scram::METHODS.contains(&sasl.method) {

View File

@@ -14,22 +14,22 @@ pub(crate) use password_hack::parse_endpoint_param;
use password_hack::PasswordHackPayload;
mod flow;
use std::io;
use std::net::IpAddr;
pub(crate) use flow::*;
use thiserror::Error;
use tokio::time::error::Elapsed;
use crate::control_plane;
use crate::error::{ReportableError, UserFacingError};
use crate::{
control_plane,
error::{ReportableError, UserFacingError},
};
use std::{io, net::IpAddr};
use thiserror::Error;
/// Convenience wrapper for the authentication error.
pub(crate) type Result<T> = std::result::Result<T, AuthError>;
/// Common authentication error.
#[derive(Debug, Error)]
pub(crate) enum AuthError {
pub(crate) enum AuthErrorImpl {
#[error(transparent)]
Web(#[from] backend::WebAuthError),
@@ -78,70 +78,80 @@ pub(crate) enum AuthError {
ConfirmationTimeout(humantime::Duration),
}
#[derive(Debug, Error)]
#[error(transparent)]
pub(crate) struct AuthError(Box<AuthErrorImpl>);
impl AuthError {
pub(crate) fn bad_auth_method(name: impl Into<Box<str>>) -> Self {
AuthError::BadAuthMethod(name.into())
AuthErrorImpl::BadAuthMethod(name.into()).into()
}
pub(crate) fn auth_failed(user: impl Into<Box<str>>) -> Self {
AuthError::AuthFailed(user.into())
AuthErrorImpl::AuthFailed(user.into()).into()
}
pub(crate) fn ip_address_not_allowed(ip: IpAddr) -> Self {
AuthError::IpAddressNotAllowed(ip)
AuthErrorImpl::IpAddressNotAllowed(ip).into()
}
pub(crate) fn too_many_connections() -> Self {
AuthError::TooManyConnections
AuthErrorImpl::TooManyConnections.into()
}
pub(crate) fn is_auth_failed(&self) -> bool {
matches!(self, AuthError::AuthFailed(_))
matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
}
pub(crate) fn user_timeout(elapsed: Elapsed) -> Self {
AuthError::UserTimeout(elapsed)
AuthErrorImpl::UserTimeout(elapsed).into()
}
pub(crate) fn confirmation_timeout(timeout: humantime::Duration) -> Self {
AuthError::ConfirmationTimeout(timeout)
AuthErrorImpl::ConfirmationTimeout(timeout).into()
}
}
impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
fn from(e: E) -> Self {
Self(Box::new(e.into()))
}
}
impl UserFacingError for AuthError {
fn to_string_client(&self) -> String {
match self {
Self::Web(e) => e.to_string_client(),
Self::GetAuthInfo(e) => e.to_string_client(),
Self::Sasl(e) => e.to_string_client(),
Self::AuthFailed(_) => self.to_string(),
Self::BadAuthMethod(_) => self.to_string(),
Self::MalformedPassword(_) => self.to_string(),
Self::MissingEndpointName => self.to_string(),
Self::Io(_) => "Internal error".to_string(),
Self::IpAddressNotAllowed(_) => self.to_string(),
Self::TooManyConnections => self.to_string(),
Self::UserTimeout(_) => self.to_string(),
Self::ConfirmationTimeout(_) => self.to_string(),
match self.0.as_ref() {
AuthErrorImpl::Web(e) => e.to_string_client(),
AuthErrorImpl::GetAuthInfo(e) => e.to_string_client(),
AuthErrorImpl::Sasl(e) => e.to_string_client(),
AuthErrorImpl::AuthFailed(_) => self.to_string(),
AuthErrorImpl::BadAuthMethod(_) => self.to_string(),
AuthErrorImpl::MalformedPassword(_) => self.to_string(),
AuthErrorImpl::MissingEndpointName => self.to_string(),
AuthErrorImpl::Io(_) => "Internal error".to_string(),
AuthErrorImpl::IpAddressNotAllowed(_) => self.to_string(),
AuthErrorImpl::TooManyConnections => self.to_string(),
AuthErrorImpl::UserTimeout(_) => self.to_string(),
AuthErrorImpl::ConfirmationTimeout(_) => self.to_string(),
}
}
}
impl ReportableError for AuthError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
Self::Web(e) => e.get_error_kind(),
Self::GetAuthInfo(e) => e.get_error_kind(),
Self::Sasl(e) => e.get_error_kind(),
Self::AuthFailed(_) => crate::error::ErrorKind::User,
Self::BadAuthMethod(_) => crate::error::ErrorKind::User,
Self::MalformedPassword(_) => crate::error::ErrorKind::User,
Self::MissingEndpointName => crate::error::ErrorKind::User,
Self::Io(_) => crate::error::ErrorKind::ClientDisconnect,
Self::IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
Self::TooManyConnections => crate::error::ErrorKind::RateLimit,
Self::UserTimeout(_) => crate::error::ErrorKind::User,
Self::ConfirmationTimeout(_) => crate::error::ErrorKind::User,
match self.0.as_ref() {
AuthErrorImpl::Web(e) => e.get_error_kind(),
AuthErrorImpl::GetAuthInfo(e) => e.get_error_kind(),
AuthErrorImpl::Sasl(e) => e.get_error_kind(),
AuthErrorImpl::AuthFailed(_) => crate::error::ErrorKind::User,
AuthErrorImpl::BadAuthMethod(_) => crate::error::ErrorKind::User,
AuthErrorImpl::MalformedPassword(_) => crate::error::ErrorKind::User,
AuthErrorImpl::MissingEndpointName => crate::error::ErrorKind::User,
AuthErrorImpl::Io(_) => crate::error::ErrorKind::ClientDisconnect,
AuthErrorImpl::IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
AuthErrorImpl::TooManyConnections => crate::error::ErrorKind::RateLimit,
AuthErrorImpl::UserTimeout(_) => crate::error::ErrorKind::User,
AuthErrorImpl::ConfirmationTimeout(_) => crate::error::ErrorKind::User,
}
}
}
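
The `AuthError(Box<AuthErrorImpl>)` wrapper above is the standard box-the-large-enum pattern: `Result<T, AuthError>` stays one pointer wide instead of carrying the largest variant inline, and the blanket `From` impl keeps `?` ergonomic at call sites. A minimal self-contained sketch of the same pattern (illustrative names, not the proxy's types):

use std::fmt;
use std::io;

// A large error enum; embedding it inline would bloat every Result.
#[derive(Debug)]
enum BigErrorImpl {
    Io(io::Error),
    Parse(&'static str),
}

impl From<io::Error> for BigErrorImpl {
    fn from(e: io::Error) -> Self {
        BigErrorImpl::Io(e)
    }
}

// Thin newtype: Result<T, BigError> is now pointer-sized.
#[derive(Debug)]
struct BigError(Box<BigErrorImpl>);

// Anything convertible to the inner enum converts to the wrapper,
// so `?` keeps working at call sites.
impl<E: Into<BigErrorImpl>> From<E> for BigError {
    fn from(e: E) -> Self {
        Self(Box::new(e.into()))
    }
}

impl fmt::Display for BigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.0.as_ref() {
            BigErrorImpl::Io(e) => write!(f, "io error: {e}"),
            BigErrorImpl::Parse(m) => write!(f, "parse error: {m}"),
        }
    }
}

fn read_config() -> Result<String, BigError> {
    // io::Error -> BigErrorImpl -> BigError via the blanket impl
    Ok(std::fs::read_to_string("config.json")?)
}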

View File

@@ -1,44 +1,41 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{net::SocketAddr, pin::pin, str::FromStr, sync::Arc, time::Duration};
use anyhow::{bail, ensure, Context};
use camino::{Utf8Path, Utf8PathBuf};
use compute_api::spec::LocalProxySpec;
use dashmap::DashMap;
use futures::future::Either;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP};
use proxy::auth::{self};
use proxy::cancellation::CancellationHandlerMain;
use proxy::config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig};
use proxy::control_plane::locks::ApiLocks;
use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use proxy::http::health_server::AppMetrics;
use proxy::intern::RoleNameInt;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::rate_limiter::{
BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo,
use proxy::{
auth::{
self,
backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
},
cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
control_plane::{
locks::ApiLocks,
messages::{EndpointJwksResponse, JwksSettings},
},
http::health_server::AppMetrics,
intern::RoleNameInt,
metrics::{Metrics, ThreadPoolMetrics},
rate_limiter::{BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo},
scram::threadpool::ThreadPool,
serverless::{self, cancel_set::CancelSet, GlobalConnPoolOptions},
RoleName,
};
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::{self, GlobalConnPoolOptions};
use proxy::url::ApiUrl;
use proxy::RoleName;
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
use clap::Parser;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio::{net::TcpListener, sync::Notify, task::JoinSet};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use utils::sentry_init::init_sentry;
use utils::{pid_file, project_build_tag, project_git_version};
use utils::{pid_file, project_build_tag, project_git_version, sentry_init::init_sentry};
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
@@ -81,10 +78,7 @@ struct LocalProxyCliArgs {
connect_to_compute_retry: String,
/// Address of the postgres server
#[clap(long, default_value = "127.0.0.1:5432")]
postgres: SocketAddr,
/// Address of the compute-ctl api service
#[clap(long, default_value = "http://127.0.0.1:3080/")]
compute_ctl: ApiUrl,
compute: SocketAddr,
/// Path of the local proxy config file
#[clap(long, default_value = "./local_proxy.json")]
config_path: Utf8PathBuf,
@@ -299,7 +293,7 @@ fn build_auth_backend(
args: &LocalProxyCliArgs,
) -> anyhow::Result<&'static auth::Backend<'static, ()>> {
let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
LocalBackend::new(args.postgres, args.compute_ctl.clone()),
LocalBackend::new(args.compute),
));
Ok(Box::leak(Box::new(auth_backend)))

View File

@@ -5,23 +5,25 @@
/// the outside. Similar to an ingress controller for HTTPS.
use std::{net::SocketAddr, sync::Arc};
use anyhow::{anyhow, bail, ensure, Context};
use clap::Arg;
use futures::future::Either;
use futures::TryFutureExt;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use proxy::stream::{PqStream, Stream};
use rustls::pki_types::PrivateKeyDer;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use anyhow::{anyhow, bail, ensure, Context};
use clap::Arg;
use futures::TryFutureExt;
use proxy::stream::{PqStream, Stream};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use utils::{project_git_version, sentry_init::init_sentry};
use tracing::{error, info, Instrument};
use utils::project_git_version;
use utils::sentry_init::init_sentry;
project_git_version!(GIT_VERSION);

View File

@@ -1,8 +1,3 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use anyhow::bail;
use aws_config::environment::EnvironmentVariableCredentialsProvider;
use aws_config::imds::credentials::ImdsCredentialsProvider;
use aws_config::meta::credentials::CredentialsProviderChain;
@@ -12,34 +7,52 @@ use aws_config::provider_config::ProviderConfig;
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use aws_config::Region;
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use proxy::cancellation::{CancelMap, CancellationHandler};
use proxy::config::{
self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig,
ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2,
};
use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::ConsoleRedirectBackend;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
use proxy::config::remote_storage_from_toml;
use proxy::config::AuthenticationConfig;
use proxy::config::CacheOptions;
use proxy::config::HttpConfig;
use proxy::config::ProjectInfoCacheOptions;
use proxy::config::ProxyProtocolV2;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::control_plane;
use proxy::http;
use proxy::http::health_server::AppMetrics;
use proxy::metrics::Metrics;
use proxy::rate_limiter::{
EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter,
};
use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::LeakyBucketConfig;
use proxy::rate_limiter::RateBucketInfo;
use proxy::rate_limiter::WakeComputeRateLimiter;
use proxy::redis::cancellation_publisher::RedisPublisherClient;
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use proxy::redis::{elasticache, notifications};
use proxy::redis::elasticache;
use proxy::redis::notifications;
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::{auth, control_plane, http, serverless, usage_metrics};
use proxy::usage_metrics;
use anyhow::bail;
use proxy::config::{self, ProxyConfig};
use proxy::serverless;
use remote_storage::RemoteStorageConfig;
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn, Instrument};
use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
use tracing::info;
use tracing::warn;
use tracing::Instrument;
use utils::{project_build_tag, project_git_version, sentry_init::init_sentry};
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);

View File

@@ -1,23 +1,31 @@
use std::convert::Infallible;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::{
convert::Infallible,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use dashmap::DashSet;
use redis::streams::{StreamReadOptions, StreamReadReply};
use redis::{AsyncCommands, FromRedisValue, Value};
use redis::{
streams::{StreamReadOptions, StreamReadReply},
AsyncCommands, FromRedisValue, Value,
};
use serde::Deserialize;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use tracing::info;
use crate::config::EndpointCacheConfig;
use crate::context::RequestMonitoring;
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
use crate::rate_limiter::GlobalRateLimiter;
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::EndpointId;
use crate::{
config::EndpointCacheConfig,
context::RequestMonitoring,
intern::{BranchIdInt, EndpointIdInt, ProjectIdInt},
metrics::{Metrics, RedisErrors, RedisEventsCount},
rate_limiter::GlobalRateLimiter,
redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
EndpointId,
};
#[derive(Deserialize, Debug, Clone)]
pub(crate) struct ControlPlaneEventKey {

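For context on the `StreamReadOptions`/`StreamReadReply` imports above: a batched, blocking read from a Redis stream with redis-rs looks roughly like this (a sketch; the stream name and limits are illustrative, not the proxy's actual polling loop):

use redis::streams::{StreamReadOptions, StreamReadReply};
use redis::AsyncCommands;

async fn read_batch(
    conn: &mut redis::aio::MultiplexedConnection,
    last_id: &str,
) -> redis::RedisResult<StreamReadReply> {
    // Fetch up to 100 entries newer than `last_id`, waiting at most 1s.
    let opts = StreamReadOptions::default().count(100).block(1000);
    conn.xread_options(&["cplane_events"], &[last_id], &opts).await
}
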
View File

@@ -1,8 +1,9 @@
use std::collections::HashSet;
use std::convert::Infallible;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::time::Duration;
use std::{
collections::HashSet,
convert::Infallible,
sync::{atomic::AtomicU64, Arc},
time::Duration,
};
use async_trait::async_trait;
use dashmap::DashMap;
@@ -12,12 +13,15 @@ use tokio::sync::Mutex;
use tokio::time::Instant;
use tracing::{debug, info};
use crate::{
auth::IpPattern,
config::ProjectInfoCacheOptions,
control_plane::AuthSecret,
intern::{EndpointIdInt, ProjectIdInt, RoleNameInt},
EndpointId, RoleName,
};
use super::{Cache, Cached};
use crate::auth::IpPattern;
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::AuthSecret;
use crate::intern::{EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::{EndpointId, RoleName};
#[async_trait]
pub(crate) trait ProjectInfoCache {
@@ -367,8 +371,7 @@ impl Cache for ProjectInfoCacheImpl {
#[cfg(test)]
mod tests {
use super::*;
use crate::scram::ServerSecret;
use crate::ProjectId;
use crate::{scram::ServerSecret, ProjectId};
#[tokio::test]
async fn test_project_info_cache_settings() {

View File

@@ -1,6 +1,9 @@
use std::borrow::Borrow;
use std::hash::Hash;
use std::time::{Duration, Instant};
use std::{
borrow::Borrow,
hash::Hash,
time::{Duration, Instant},
};
use tracing::debug;
// This seems to make more sense than `lru` or `cached`:
//
@@ -12,10 +15,8 @@ use std::time::{Duration, Instant};
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
use tracing::debug;
use super::common::Cached;
use super::{timed_lru, Cache};
use super::{common::Cached, timed_lru, Cache};
/// An implementation of timed LRU cache with fixed capacity.
/// Key properties:

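To ground the crate comparison in the comment above, this is the core `hashlink::LruCache` behavior the timed-LRU builds on (sketch; capacity and keys are illustrative):

use hashlink::LruCache;

fn main() {
    let mut cache: LruCache<&str, u32> = LruCache::new(2);
    cache.insert("a", 1);
    cache.insert("b", 2);
    cache.get(&"a"); // touching "a" marks it most-recently used
    cache.insert("c", 3); // evicts "b", the least-recently-used entry
    assert!(cache.get(&"b").is_none());
    assert_eq!(cache.get(&"a"), Some(&1));
}
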
View File

@@ -1,8 +1,6 @@
use std::net::SocketAddr;
use std::sync::Arc;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
@@ -10,10 +8,12 @@ use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
use uuid::Uuid;
use crate::error::ReportableError;
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
use crate::redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
use crate::{
error::ReportableError,
metrics::{CancellationRequest, CancellationSource, Metrics},
redis::cancellation_publisher::{
CancellationPublisher, CancellationPublisherMut, RedisPublisherClient,
},
};
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;

View File

@@ -1,31 +1,25 @@
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use crate::{
auth::parse_endpoint_param,
cancellation::CancelClosure,
context::RequestMonitoring,
control_plane::{errors::WakeComputeError, messages::MetricsAuxInfo, provider::ApiLockError},
error::{ReportableError, UserFacingError},
metrics::{Metrics, NumDbConnectionsGuard},
proxy::neon_option,
Host,
};
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::StartupMessageParams;
use rustls::client::danger::ServerCertVerifier;
use rustls::pki_types::InvalidDnsNameError;
use rustls::{client::danger::ServerCertVerifier, pki_types::InvalidDnsNameError};
use std::{io, net::SocketAddr, sync::Arc, time::Duration};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres_rustls::MakeRustlsConnect;
use tracing::{error, info, warn};
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::control_plane::provider::ApiLockError;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::proxy::neon_option;
use crate::Host;
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
#[derive(Debug, Error)]

View File

@@ -1,101 +0,0 @@
use compute_api::responses::GenericAPIError;
use hyper::{Method, StatusCode};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::url::ApiUrl;
use crate::{http, DbName, RoleName};
pub struct ComputeCtlApi {
pub(crate) api: http::Endpoint,
}
#[derive(Serialize, Debug)]
pub struct ExtensionInstallRequest {
pub extension: &'static str,
pub database: DbName,
pub version: &'static str,
}
#[derive(Serialize, Debug)]
pub struct SetRoleGrantsRequest {
pub database: DbName,
pub schema: &'static str,
pub privileges: Vec<Privilege>,
pub role: RoleName,
}
#[derive(Clone, Debug, Deserialize)]
pub struct ExtensionInstallResponse {}
#[derive(Clone, Debug, Deserialize)]
pub struct SetRoleGrantsResponse {}
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
#[serde(rename_all = "UPPERCASE")]
pub enum Privilege {
Usage,
}
#[derive(Error, Debug)]
pub enum ComputeCtlError {
#[error("connection error: {0}")]
ConnectionError(#[source] reqwest_middleware::Error),
#[error("request error [{status}]: {body:?}")]
RequestError {
status: StatusCode,
body: Option<GenericAPIError>,
},
#[error("response parsing error: {0}")]
ResponseError(#[source] reqwest::Error),
}
impl ComputeCtlApi {
pub async fn install_extension(
&self,
req: &ExtensionInstallRequest,
) -> Result<ExtensionInstallResponse, ComputeCtlError> {
self.generic_request(req, Method::POST, |url| {
url.path_segments_mut().push("extensions");
})
.await
}
pub async fn grant_role(
&self,
req: &SetRoleGrantsRequest,
) -> Result<SetRoleGrantsResponse, ComputeCtlError> {
self.generic_request(req, Method::POST, |url| {
url.path_segments_mut().push("grants");
})
.await
}
async fn generic_request<Req, Resp>(
&self,
req: &Req,
method: Method,
url: impl for<'a> FnOnce(&'a mut ApiUrl),
) -> Result<Resp, ComputeCtlError>
where
Req: Serialize,
Resp: DeserializeOwned,
{
let resp = self
.api
.request_with_url(method, url)
.json(req)
.send()
.await
.map_err(ComputeCtlError::ConnectionError)?;
let status = resp.status();
if status.is_client_error() || status.is_server_error() {
let body = resp.json().await.ok();
return Err(ComputeCtlError::RequestError { status, body });
}
resp.json().await.map_err(ComputeCtlError::ResponseError)
}
}

View File

@@ -1,27 +1,29 @@
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use crate::{
auth::backend::{jwt::JwkCache, AuthRateLimiter},
control_plane::locks::ApiLocks,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool,
serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
Host,
};
use anyhow::{bail, ensure, Context, Ok};
use clap::ValueEnum;
use itertools::Itertools;
use remote_storage::RemoteStorageConfig;
use rustls::crypto::ring::sign;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::{
crypto::ring::sign,
pki_types::{CertificateDer, PrivateKeyDer},
};
use sha2::{Digest, Sha256};
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
time::Duration,
};
use tracing::{error, info};
use x509_parser::oid_registry;
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::AuthRateLimiter;
use crate::control_plane::locks::ApiLocks;
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::GlobalConnPoolOptions;
use crate::Host;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub metric_collection: Option<MetricCollectionConfig>,
@@ -690,9 +692,10 @@ impl FromStr for ConcurrencyLockOptions {
#[cfg(test)]
mod tests {
use super::*;
use crate::rate_limiter::Aimd;
use super::*;
#[test]
fn test_parse_cache_options() -> anyhow::Result<()> {
let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;

View File

@@ -1,22 +1,25 @@
use std::sync::Arc;
use crate::auth::backend::ConsoleRedirectBackend;
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::proxy::{
prepare_client_connection, run_until_cancelled, ClientRequestError, ErrorSource,
};
use crate::{
cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal},
context::RequestMonitoring,
error::ReportableError,
metrics::{Metrics, NumClientConnectionsGuard},
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
};
use futures::TryFutureExt;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, Instrument};
use crate::auth::backend::ConsoleRedirectBackend;
use crate::cancellation::{CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::read_proxy_protocol;
use crate::proxy::connect_compute::{connect_to_compute, TcpMechanism};
use crate::proxy::handshake::{handshake, HandshakeData};
use crate::proxy::passthrough::ProxyPassthrough;
use crate::proxy::{
prepare_client_connection, run_until_cancelled, ClientRequestError, ErrorSource,
connect_compute::{connect_to_compute, TcpMechanism},
passthrough::ProxyPassthrough,
};
pub async fn task_main(

View File

@@ -1,25 +1,24 @@
//! Connection request monitoring contexts
use std::net::IpAddr;
use chrono::Utc;
use once_cell::sync::OnceCell;
use pq_proto::StartupMessageParams;
use smol_str::SmolStr;
use std::net::IpAddr;
use tokio::sync::mpsc;
use tracing::field::display;
use tracing::{debug, info, info_span, Span};
use tracing::{debug, field::display, info, info_span, Span};
use try_lock::TryLock;
use uuid::Uuid;
use self::parquet::RequestData;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::error::ErrorKind;
use crate::intern::{BranchIdInt, ProjectIdInt};
use crate::metrics::{
ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting,
use crate::{
control_plane::messages::{ColdStartInfo, MetricsAuxInfo},
error::ErrorKind,
intern::{BranchIdInt, ProjectIdInt},
metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting},
DbName, EndpointId, RoleName,
};
use crate::{DbName, EndpointId, RoleName};
use self::parquet::RequestData;
pub mod parquet;

View File

@@ -1,28 +1,29 @@
use std::sync::Arc;
use std::time::SystemTime;
use std::{sync::Arc, time::SystemTime};
use anyhow::Context;
use bytes::buf::Writer;
use bytes::{BufMut, BytesMut};
use bytes::{buf::Writer, BufMut, BytesMut};
use chrono::{Datelike, Timelike};
use futures::{Stream, StreamExt};
use parquet::basic::Compression;
use parquet::file::metadata::RowGroupMetaDataPtr;
use parquet::file::properties::{WriterProperties, WriterPropertiesPtr, DEFAULT_PAGE_SIZE};
use parquet::file::writer::SerializedFileWriter;
use parquet::record::RecordWriter;
use parquet::{
basic::Compression,
file::{
metadata::RowGroupMetaDataPtr,
properties::{WriterProperties, WriterPropertiesPtr, DEFAULT_PAGE_SIZE},
writer::SerializedFileWriter,
},
record::RecordWriter,
};
use pq_proto::StartupMessageParams;
use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel};
use serde::ser::SerializeMap;
use tokio::sync::mpsc;
use tokio::time;
use tokio::{sync::mpsc, time};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, Span};
use utils::backoff;
use crate::{config::remote_storage_from_toml, context::LOG_CHAN_DISCONNECT};
use super::{RequestMonitoringInner, LOG_CHAN};
use crate::config::remote_storage_from_toml;
use crate::context::LOG_CHAN_DISCONNECT;
#[derive(clap::Args, Clone, Debug)]
pub struct ParquetUploadArgs {
@@ -406,26 +407,26 @@ async fn upload_parquet(
#[cfg(test)]
mod tests {
use std::net::Ipv4Addr;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::{net::Ipv4Addr, num::NonZeroUsize, sync::Arc};
use camino::Utf8Path;
use clap::Parser;
use futures::{Stream, StreamExt};
use itertools::Itertools;
use parquet::basic::{Compression, ZstdLevel};
use parquet::file::properties::{WriterProperties, DEFAULT_PAGE_SIZE};
use parquet::file::reader::FileReader;
use parquet::file::serialized_reader::SerializedFileReader;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use parquet::{
basic::{Compression, ZstdLevel},
file::{
properties::{WriterProperties, DEFAULT_PAGE_SIZE},
reader::FileReader,
serialized_reader::SerializedFileReader,
},
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use remote_storage::{
GenericRemoteStorage, RemoteStorageConfig, RemoteStorageKind, S3Config,
DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
};
use tokio::sync::mpsc;
use tokio::time;
use tokio::{sync::mpsc, time};
use walkdir::WalkDir;
use super::{worker_inner, ParquetConfig, ParquetUploadArgs, RequestData};
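
The test imports above drive the parquet writer settings; constructing the properties follows the builder pattern (a sketch with an assumed zstd level, not the exact production configuration):

use parquet::basic::{Compression, ZstdLevel};
use parquet::file::properties::WriterProperties;

fn writer_props() -> WriterProperties {
    WriterProperties::builder()
        // zstd compresses the highly repetitive request-log columns well
        .set_compression(Compression::ZSTD(ZstdLevel::try_new(1).unwrap()))
        .build()
}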

View File

@@ -1,9 +1,9 @@
use std::fmt::{self, Display};
use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use std::fmt::{self, Display};
use crate::auth::IpPattern;
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::proxy::retry::CouldRetry;
@@ -362,9 +362,8 @@ pub struct JwksSettings {
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
use serde_json::json;
fn dummy_aux() -> serde_json::Value {
json!({

View File

@@ -1,16 +1,16 @@
use std::convert::Infallible;
use crate::{
control_plane::messages::{DatabaseInfo, KickSession},
waiters::{self, Waiter, Waiters},
};
use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use std::convert::Infallible;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
use crate::control_plane::messages::{DatabaseInfo, KickSession};
use crate::waiters::{self, Waiter, Waiters};
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
/// Give the caller an opportunity to wait for the cloud's reply.

View File

@@ -1,29 +1,28 @@
//! Mock console backend which relies on a user-provided postgres instance.
use std::str::FromStr;
use std::sync::Arc;
use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
};
use crate::{
auth::backend::jwt::AuthRule, context::RequestMonitoring,
control_plane::errors::GetEndpointJwksError, intern::RoleNameInt, RoleName,
};
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use crate::{
control_plane::{
messages::MetricsAuxInfo,
provider::{CachedAllowedIps, CachedRoleSecret},
},
BranchId, EndpointId, ProjectId,
};
use futures::TryFutureExt;
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
use tokio_postgres::config::SslMode;
use tokio_postgres::Client;
use tokio_postgres::{config::SslMode, Client};
use tracing::{error, info, info_span, warn, Instrument};
use super::errors::{ApiError, GetAuthInfoError, WakeComputeError};
use super::{AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::IpPattern;
use crate::cache::Cached;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::GetEndpointJwksError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::control_plane::provider::{CachedAllowedIps, CachedRoleSecret};
use crate::error::io_error;
use crate::intern::RoleNameInt;
use crate::url::ApiUrl;
use crate::{compute, scram, BranchId, EndpointId, ProjectId, RoleName};
#[derive(Debug, Error)]
enum MockApiError {
#[error("Failed to read password: {0}")]

View File

@@ -2,36 +2,39 @@
pub mod mock;
pub mod neon;
use std::hash::Hash;
use std::sync::Arc;
use std::time::Duration;
use super::messages::{ControlPlaneError, MetricsAuxInfo};
use crate::{
auth::{
backend::{
jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError},
ComputeCredentialKeys, ComputeUserInfo,
},
IpPattern,
},
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
compute,
config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
context::RequestMonitoring,
error::ReportableError,
intern::ProjectIdInt,
metrics::ApiLockMetrics,
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
scram, EndpointCacheKey, EndpointId,
};
use dashmap::DashMap;
use std::{hash::Hash, sync::Arc, time::Duration};
use tokio::time::Instant;
use tracing::info;
use super::messages::{ControlPlaneError, MetricsAuxInfo};
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::IpPattern;
use crate::cache::endpoints::EndpointsCache;
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::cache::{Cached, TimedLru};
use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::intern::ProjectIdInt;
use crate::metrics::ApiLockMetrics;
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
use crate::{compute, scram, EndpointCacheKey, EndpointId};
pub(crate) mod errors {
use crate::{
control_plane::messages::{self, ControlPlaneError, Reason},
error::{io_error, ErrorKind, ReportableError, UserFacingError},
proxy::retry::CouldRetry,
};
use thiserror::Error;
use super::ApiLockError;
use crate::control_plane::messages::{self, ControlPlaneError, Reason};
use crate::error::{io_error, ErrorKind, ReportableError, UserFacingError};
use crate::proxy::retry::CouldRetry;
/// A go-to error message which doesn't leak any detail.
pub(crate) const REQUEST_FAILED: &str = "Console request failed";

View File

@@ -1,30 +1,30 @@
//! Production console backend.
use std::sync::Arc;
use std::time::Duration;
use ::http::header::AUTHORIZATION;
use ::http::HeaderName;
use futures::TryFutureExt;
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
use tracing::{debug, info, info_span, warn, Instrument};
use super::super::messages::{ControlPlaneError, GetRoleSecret, WakeCompute};
use super::errors::{ApiError, GetAuthInfoError, WakeComputeError};
use super::{
super::messages::{ControlPlaneError, GetRoleSecret, WakeCompute},
errors::{ApiError, GetAuthInfoError, WakeComputeError},
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
NodeInfo,
};
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::ComputeUserInfo;
use crate::cache::Cached;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::GetEndpointJwksError;
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
use crate::metrics::{CacheOutcome, Metrics};
use crate::rate_limiter::WakeComputeRateLimiter;
use crate::{compute, http, scram, EndpointCacheKey, EndpointId};
use crate::{
auth::backend::{jwt::AuthRule, ComputeUserInfo},
compute,
control_plane::{
errors::GetEndpointJwksError,
messages::{ColdStartInfo, EndpointJwksResponse, Reason},
},
http,
metrics::{CacheOutcome, Metrics},
rate_limiter::WakeComputeRateLimiter,
scram, EndpointCacheKey, EndpointId,
};
use crate::{cache::Cached, context::RequestMonitoring};
use ::http::{header::AUTHORIZATION, HeaderName};
use futures::TryFutureExt;
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
use tracing::{debug, info, info_span, warn, Instrument};
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");

View File

@@ -1,5 +1,4 @@
use std::error::Error as StdError;
use std::{fmt, io};
use std::{error::Error as StdError, fmt, io};
use measured::FixedCardinalityLabel;

View File

@@ -1,18 +1,19 @@
use std::convert::Infallible;
use std::net::TcpListener;
use std::sync::{Arc, Mutex};
use anyhow::{anyhow, bail};
use hyper0::header::CONTENT_TYPE;
use hyper0::{Body, Request, Response, StatusCode};
use measured::text::BufferedTextEncoder;
use measured::MetricGroup;
use hyper0::{header::CONTENT_TYPE, Body, Request, Response, StatusCode};
use measured::{text::BufferedTextEncoder, MetricGroup};
use metrics::NeonMetrics;
use std::{
convert::Infallible,
net::TcpListener,
sync::{Arc, Mutex},
};
use tracing::{info, info_span};
use utils::http::endpoint::{self, request_span};
use utils::http::error::ApiError;
use utils::http::json::json_response;
use utils::http::{RouterBuilder, RouterService};
use utils::http::{
endpoint::{self, request_span},
error::ApiError,
json::json_response,
RouterBuilder, RouterService,
};
use crate::jemalloc;

View File

@@ -8,18 +8,19 @@ use std::time::Duration;
use anyhow::bail;
use bytes::Bytes;
use http::Method;
use http_body_util::BodyExt;
use hyper::body::Body;
pub(crate) use reqwest::{Request, Response};
use reqwest_middleware::RequestBuilder;
pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
pub(crate) use reqwest_retry::policies::ExponentialBackoff;
pub(crate) use reqwest_retry::RetryTransientMiddleware;
use serde::de::DeserializeOwned;
use crate::metrics::{ConsoleRequest, Metrics};
use crate::url::ApiUrl;
pub(crate) use reqwest::{Request, Response};
pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
pub(crate) use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use crate::{
metrics::{ConsoleRequest, Metrics},
url::ApiUrl,
};
use reqwest_middleware::RequestBuilder;
/// This is the preferred way to create new http clients,
/// because it takes care of observability (OpenTelemetry).
@@ -94,19 +95,9 @@ impl Endpoint {
/// Return a [builder](RequestBuilder) for a `GET` request,
/// accepting a closure to modify the url path segments for more complex path queries.
pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
self.request_with_url(Method::GET, f)
}
/// Return a [builder](RequestBuilder) for a request,
/// accepting a closure to modify the url path segments for more complex path queries.
pub(crate) fn request_with_url(
&self,
method: Method,
f: impl for<'a> FnOnce(&'a mut ApiUrl),
) -> RequestBuilder {
let mut url = self.endpoint.clone();
f(&mut url);
self.client.request(method, url.into_inner())
self.client.get(url.into_inner())
}
/// Execute a [request](reqwest::Request).
@@ -151,9 +142,8 @@ pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned>(
#[cfg(test)]
mod tests {
use reqwest::Client;
use super::*;
use reqwest::Client;
#[test]
fn optional_query_params() -> anyhow::Result<()> {

View File

@@ -1,8 +1,6 @@
use std::hash::BuildHasherDefault;
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::ops::Index;
use std::sync::OnceLock;
use std::{
hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
};
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
use rustc_hash::FxHasher;
@@ -210,9 +208,10 @@ impl From<ProjectId> for ProjectIdInt {
mod tests {
use std::sync::OnceLock;
use super::InternId;
use crate::intern::StringInterner;
use super::InternId;
struct MyId;
impl InternId for MyId {
fn get_interner() -> &'static StringInterner<Self> {
@@ -223,8 +222,7 @@ mod tests {
#[test]
fn push_many_strings() {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand_distr::Zipf;
let endpoint_dist = Zipf::new(500000, 0.8).unwrap();

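The test above hammers the interner with a Zipf-distributed stream of endpoint names; the underlying `lasso::ThreadedRodeo` behavior it relies on looks like this (sketch; the endpoint string is made up):

use lasso::ThreadedRodeo;

fn main() {
    let rodeo = ThreadedRodeo::default();
    let a = rodeo.get_or_intern("ep-cool-darkness-123456");
    let b = rodeo.get_or_intern("ep-cool-darkness-123456");
    // Equal strings intern to the same cheap Copy key...
    assert_eq!(a, b);
    // ...which resolves back to a single shared allocation.
    assert_eq!(rodeo.resolve(&a), "ep-cool-darkness-123456");
}
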
View File

@@ -1,12 +1,14 @@
use std::marker::PhantomData;
use measured::label::NoLabels;
use measured::metric::gauge::GaugeState;
use measured::metric::group::Encoding;
use measured::metric::name::MetricNameEncoder;
use measured::metric::{MetricEncoding, MetricFamilyEncoding, MetricType};
use measured::text::TextEncoder;
use measured::{LabelGroup, MetricGroup};
use measured::{
label::NoLabels,
metric::{
gauge::GaugeState, group::Encoding, name::MetricNameEncoder, MetricEncoding,
MetricFamilyEncoding, MetricType,
},
text::TextEncoder,
LabelGroup, MetricGroup,
};
use tikv_jemalloc_ctl::{config, epoch, epoch_mib, stats, version};
pub struct MetricRecorder {

View File

@@ -94,7 +94,6 @@ pub mod auth;
pub mod cache;
pub mod cancellation;
pub mod compute;
pub mod compute_ctl;
pub mod config;
pub mod console_redirect_proxy;
pub mod context;

View File

@@ -1,10 +1,14 @@
use tracing::Subscriber;
use tracing_subscriber::filter::{EnvFilter, LevelFilter};
use tracing_subscriber::fmt::format::{Format, Full};
use tracing_subscriber::fmt::time::SystemTime;
use tracing_subscriber::fmt::{FormatEvent, FormatFields};
use tracing_subscriber::prelude::*;
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::{
filter::{EnvFilter, LevelFilter},
fmt::{
format::{Format, Full},
time::SystemTime,
FormatEvent, FormatFields,
},
prelude::*,
registry::LookupSpan,
};
/// Initialize logging and OpenTelemetry tracing and exporter.
///

View File

@@ -1,16 +1,14 @@
use std::sync::{Arc, OnceLock};
use lasso::ThreadedRodeo;
use measured::label::{
FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet,
};
use measured::metric::histogram::Thresholds;
use measured::metric::name::MetricName;
use measured::{
label::{FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet},
metric::{histogram::Thresholds, name::MetricName},
Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup,
MetricGroup,
};
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec};
use tokio::time::{self, Instant};
use crate::control_plane::messages::ColdStartInfo;

View File

@@ -1,9 +1,11 @@
//! Proxy Protocol V2 implementation
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{
io,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
use bytes::BytesMut;
use pin_project_lite::pin_project;

View File

@@ -1,23 +1,24 @@
use crate::{
auth::backend::ComputeCredentialKeys,
compute::COULD_NOT_CONNECT,
compute::{self, PostgresConnection},
config::RetryConfig,
context::RequestMonitoring,
control_plane::{self, errors::WakeComputeError, locks::ApiLocks, CachedNodeInfo, NodeInfo},
error::ReportableError,
metrics::{ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType},
proxy::{
retry::{retry_after, should_retry, CouldRetry},
wake_compute::wake_compute,
},
Host,
};
use async_trait::async_trait;
use pq_proto::StartupMessageParams;
use tokio::time;
use tracing::{debug, info, warn};
use super::retry::ShouldRetryWakeCompute;
use crate::auth::backend::ComputeCredentialKeys;
use crate::compute::{self, PostgresConnection, COULD_NOT_CONNECT};
use crate::config::RetryConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::proxy::retry::{retry_after, should_retry, CouldRetry};
use crate::proxy::wake_compute::wake_compute;
use crate::Host;
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);

View File

@@ -1,11 +1,11 @@
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::info;
use std::future::poll_fn;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::info;
#[derive(Debug)]
enum TransferState {
Running(CopyBuffer),
@@ -256,9 +256,8 @@ impl CopyBuffer {
#[cfg(test)]
mod tests {
use tokio::io::AsyncWriteExt;
use super::*;
use tokio::io::AsyncWriteExt;
#[tokio::test]
async fn test_client_to_compute() {

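The module above hand-rolls its client<->compute copy loop so errors can be attributed to a side (`ErrorSource`); the stock tokio primitive it specializes behaves like this (sketch over in-memory duplex pipes):

use tokio::io::{copy_bidirectional, duplex, AsyncReadExt, AsyncWriteExt};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Two in-memory pipes standing in for the client and compute sockets.
    let (mut client, mut proxy_client_side) = duplex(64);
    let (mut proxy_compute_side, mut compute) = duplex(64);

    // The proxy task shovels bytes both ways until both directions hit EOF.
    let proxy = tokio::spawn(async move {
        copy_bidirectional(&mut proxy_client_side, &mut proxy_compute_side).await
    });

    client.write_all(b"startup").await?;
    client.shutdown().await?; // EOF on the client -> compute direction

    let mut buf = Vec::new();
    compute.read_to_end(&mut buf).await?;
    assert_eq!(buf, b"startup");
    compute.shutdown().await?; // EOF on the compute -> client direction

    let (c2s, s2c) = proxy.await.unwrap()?;
    assert_eq!((c2s, s2c), (7, 0));
    Ok(())
}
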
View File

@@ -1,19 +1,21 @@
use bytes::Buf;
use pq_proto::framed::Framed;
use pq_proto::{
BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
framed::Framed, BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion,
StartupMessageParams,
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
use crate::auth::endpoint_sni;
use crate::config::{TlsConfig, PG_ALPN_PROTOCOL};
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::ERR_INSECURE_CONNECTION;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::{
auth::endpoint_sni,
config::{TlsConfig, PG_ALPN_PROTOCOL},
context::RequestMonitoring,
error::ReportableError,
metrics::Metrics,
proxy::ERR_INSECURE_CONNECTION,
stream::{PqStream, Stream, StreamUpgradeError},
};
#[derive(Error, Debug)]
pub(crate) enum HandshakeError {

View File

@@ -7,32 +7,40 @@ pub(crate) mod handshake;
pub(crate) mod passthrough;
pub(crate) mod retry;
pub(crate) mod wake_compute;
use std::sync::Arc;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
pub use copy_bidirectional::{copy_bidirectional_client_compute, ErrorSource};
use crate::config::ProxyProtocolV2;
use crate::{
auth,
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
compute,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
error::ReportableError,
metrics::{Metrics, NumClientConnectionsGuard},
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
EndpointCacheKey,
};
use futures::TryFutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use self::connect_compute::{connect_to_compute, TcpMechanism};
use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestMonitoring;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::read_proxy_protocol;
use crate::proxy::handshake::{handshake, HandshakeData};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::{auth, compute, EndpointCacheKey};
use self::{
connect_compute::{connect_to_compute, TcpMechanism},
passthrough::ProxyPassthrough,
};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";

View File

@@ -1,14 +1,16 @@
use crate::{
cancellation,
compute::PostgresConnection,
control_plane::messages::MetricsAuxInfo,
metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard},
stream::Stream,
usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]

View File

@@ -1,11 +1,7 @@
use std::error::Error;
use std::io;
use crate::{compute, config::RetryConfig};
use std::{error::Error, io};
use tokio::time;
use crate::compute;
use crate::config::RetryConfig;
pub(crate) trait CouldRetry {
/// Returns true if the error could be retried
fn could_retry(&self) -> bool;

View File

@@ -6,6 +6,7 @@
use std::fmt::Debug;
use super::*;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_protocol::message::frontend;
@@ -13,8 +14,6 @@ use tokio::io::{AsyncReadExt, DuplexStream};
use tokio_postgres::tls::TlsConnect;
use tokio_util::codec::{Decoder, Encoder};
use super::*;
enum Intercept {
None,
Methods,

View File

@@ -4,16 +4,6 @@ mod mitm;
use std::time::Duration;
use anyhow::{bail, Context};
use async_trait::async_trait;
use http::StatusCode;
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
use rustls::pki_types;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
use super::connect_compute::ConnectMechanism;
use super::retry::CouldRetry;
use super::*;
@@ -28,6 +18,15 @@ use crate::control_plane::provider::{
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::{sasl, scram, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
use http::StatusCode;
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
use rustls::pki_types;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
/// Generate a set of TLS certificates: CA + server.
fn generate_certs(
@@ -337,8 +336,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock()));
use rand::distributions::Alphanumeric;
use rand::Rng;
use rand::{distributions::Alphanumeric, Rng};
let password: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(rand::random::<u8>() as usize)

View File

@@ -1,17 +1,16 @@
use hyper::StatusCode;
use tracing::{error, info, warn};
use super::connect_compute::ComputeConnectBackend;
use crate::config::RetryConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::{ControlPlaneError, Reason};
use crate::control_plane::provider::CachedNodeInfo;
use crate::control_plane::{errors::WakeComputeError, provider::CachedNodeInfo};
use crate::metrics::{
ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType,
WakeupFailureKind,
};
use crate::proxy::retry::{retry_after, should_retry};
use hyper::StatusCode;
use tracing::{error, info, warn};
use super::connect_compute::ComputeConnectBackend;
pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
num_retries: &mut u32,

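`wake_compute` retries with the `retry_after`/`should_retry` helpers imported above; the generic shape of such a loop is as follows (a sketch under assumptions — the config fields here are illustrative, not the proxy's `RetryConfig`):

use std::future::Future;
use std::time::Duration;

struct Retry {
    base: Duration,
    factor: f64,
    max_retries: u32,
}

async fn with_retries<T, E, F, Fut>(cfg: &Retry, mut op: F) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
{
    let mut attempt: u32 = 0;
    loop {
        match op().await {
            Ok(v) => return Ok(v),
            // Out of budget: surface the last error.
            Err(e) if attempt + 1 >= cfg.max_retries => return Err(e),
            Err(_) => {
                // Exponential backoff: base * factor^attempt.
                let backoff = cfg.base.mul_f64(cfg.factor.powi(attempt as i32));
                tokio::time::sleep(backoff).await;
                attempt += 1;
            }
        }
    }
}
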
View File

@@ -1,5 +1,7 @@
use std::hash::Hash;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{
hash::Hash,
sync::atomic::{AtomicUsize, Ordering},
};
use ahash::RandomState;
use dashmap::DashMap;

View File

@@ -1,12 +1,10 @@
//! Algorithms for controlling concurrency limits.
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use parking_lot::Mutex;
use tokio::sync::Notify;
use tokio::time::error::Elapsed;
use tokio::time::Instant;
use std::{pin::pin, sync::Arc, time::Duration};
use tokio::{
sync::Notify,
time::{error::Elapsed, Instant},
};
use self::aimd::Aimd;

View File

@@ -60,11 +60,12 @@ impl LimitAlgorithm for Aimd {
mod tests {
use std::time::Duration;
use super::*;
use crate::rate_limiter::limit_algorithm::{
DynamicLimiter, RateLimitAlgorithm, RateLimiterConfig,
};
use super::*;
#[tokio::test(start_paused = true)]
async fn increase_decrease() {
let config = RateLimiterConfig {

View File

@@ -1,14 +1,17 @@
use std::borrow::Cow;
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hash};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use std::{
borrow::Cow,
collections::hash_map::RandomState,
hash::{BuildHasher, Hash},
sync::{
atomic::{AtomicUsize, Ordering},
Mutex,
},
};
use anyhow::bail;
use dashmap::DashMap;
use itertools::Itertools;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand::{rngs::StdRng, Rng, SeedableRng};
use tokio::time::{Duration, Instant};
use tracing::info;
@@ -240,17 +243,14 @@ impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
#[cfg(test)]
mod tests {
use std::hash::BuildHasherDefault;
use std::time::Duration;
use std::{hash::BuildHasherDefault, time::Duration};
use rand::SeedableRng;
use rustc_hash::FxHasher;
use tokio::time;
use super::{BucketRateLimiter, WakeComputeRateLimiter};
use crate::intern::EndpointIdInt;
use crate::rate_limiter::RateBucketInfo;
use crate::EndpointId;
use crate::{intern::EndpointIdInt, rate_limiter::RateBucketInfo, EndpointId};
#[test]
fn rate_bucket_rpi() {

View File

@@ -2,11 +2,13 @@ mod leaky_bucket;
mod limit_algorithm;
mod limiter;
pub use leaky_bucket::{EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter};
#[cfg(test)]
pub(crate) use limit_algorithm::aimd::Aimd;
pub(crate) use limit_algorithm::{
DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,
};
pub(crate) use limiter::GlobalRateLimiter;
pub use leaky_bucket::{EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter};
pub use limiter::{BucketRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
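
The `LeakyBucketConfig` re-export above names the classic algorithm; a minimal illustration of the accounting (not the proxy's implementation, which adds per-endpoint sharding on top):

use std::time::Instant;

struct LeakyBucket {
    level: f64,   // current fill
    max: f64,     // burst capacity
    rps: f64,     // drain rate, tokens per second
    last: Instant,
}

impl LeakyBucket {
    fn check(&mut self, now: Instant, cost: f64) -> bool {
        // Drain whatever leaked out since the last request...
        let elapsed = now.saturating_duration_since(self.last).as_secs_f64();
        self.level = (self.level - elapsed * self.rps).max(0.0);
        self.last = now;
        // ...then admit the request only if it still fits the bucket.
        if self.level + cost > self.max {
            return false;
        }
        self.level += cost;
        true
    }
}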

View File

@@ -5,10 +5,13 @@ use redis::AsyncCommands;
use tokio::sync::Mutex;
use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME};
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
use super::{
connection_with_credentials_provider::ConnectionWithCredentialsProvider,
notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME},
};
pub trait CancellationPublisherMut: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(

View File

@@ -1,9 +1,10 @@
use std::sync::Arc;
use std::time::Duration;
use std::{sync::Arc, time::Duration};
use futures::FutureExt;
use redis::aio::{ConnectionLike, MultiplexedConnection};
use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult};
use redis::{
aio::{ConnectionLike, MultiplexedConnection},
ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult,
};
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn};

View File

@@ -1,5 +1,4 @@
use std::convert::Infallible;
use std::sync::Arc;
use std::{convert::Infallible, sync::Arc};
use futures::StreamExt;
use pq_proto::CancelKeyData;
@@ -9,10 +8,12 @@ use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::cache::project_info::ProjectInfoCache;
use crate::cancellation::{CancelMap, CancellationHandler};
use crate::intern::{ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
use crate::{
cache::project_info::ProjectInfoCache,
cancellation::{CancelMap, CancellationHandler},
intern::{ProjectIdInt, RoleNameInt},
metrics::{Metrics, RedisErrors, RedisEventsCount},
};
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
@@ -268,10 +269,10 @@ where
#[cfg(test)]
mod tests {
use serde_json::json;
use crate::{ProjectId, RoleName};
use super::*;
use crate::{ProjectId, RoleName};
use serde_json::json;
#[test]
fn parse_allowed_ips() -> anyhow::Result<()> {

View File

@@ -1,8 +1,7 @@
//! Definitions for SASL messages.
use pq_proto::{BeAuthenticationSaslMessage, BeMessage};
use crate::parse::{split_at_const, split_cstr};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage};
/// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
#[derive(Debug)]

View File

@@ -10,14 +10,13 @@ mod channel_binding;
mod messages;
mod stream;
use crate::error::{ReportableError, UserFacingError};
use std::io;
use thiserror::Error;
pub(crate) use channel_binding::ChannelBinding;
pub(crate) use messages::FirstMessage;
pub(crate) use stream::{Outcome, SaslStream};
use thiserror::Error;
use crate::error::{ReportableError, UserFacingError};
/// Fine-grained auth errors help in writing tests.
#[derive(Error, Debug)]

View File

@@ -1,14 +1,11 @@
//! Abstraction for the string-oriented SASL protocols.
use super::{messages::ServerMessage, Mechanism};
use crate::stream::PqStream;
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use super::messages::ServerMessage;
use super::Mechanism;
use crate::stream::PqStream;
/// Abstracts away all peculiarities of libpq's protocol.
pub(crate) struct SaslStream<'a, S> {
/// The underlying stream.

View File

@@ -69,9 +69,7 @@ impl CountMinSketch {
#[cfg(test)]
mod tests {
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::{Rng, SeedableRng};
use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
use super::CountMinSketch;
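
Since the test above exercises `CountMinSketch`, here is the idea in miniature (an illustrative reimplementation, not the proxy's code): d independent hash rows of w counters; an add increments one counter per row, and the estimate takes the row-wise minimum, so it can overcount but never undercount.

use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hash, Hasher};

struct Cms {
    rows: Vec<Vec<u32>>,       // d x w counter matrix
    hashers: Vec<RandomState>, // one independent hash function per row
    width: usize,
}

impl Cms {
    fn new(depth: usize, width: usize) -> Self {
        Cms {
            rows: vec![vec![0; width]; depth],
            hashers: (0..depth).map(|_| RandomState::new()).collect(),
            width,
        }
    }

    fn add<T: Hash>(&mut self, item: &T, count: u32) {
        for (row, h) in self.rows.iter_mut().zip(&self.hashers) {
            let mut state = h.build_hasher();
            item.hash(&mut state);
            let idx = state.finish() as usize % self.width;
            row[idx] = row[idx].saturating_add(count);
        }
    }

    fn estimate<T: Hash>(&self, item: &T) -> u32 {
        self.rows
            .iter()
            .zip(&self.hashers)
            .map(|(row, h)| {
                let mut state = h.build_hasher();
                item.hash(&mut state);
                row[state.finish() as usize % self.width]
            })
            .min()
            .unwrap_or(0)
    }
}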

View File

@@ -209,8 +209,7 @@ impl sasl::Mechanism for Exchange<'_> {
type Output = super::ScramKey;
fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
use sasl::Step;
use ExchangeState;
use {sasl::Step, ExchangeState};
match &self.state {
ExchangeState::Initial(init) => {
match init.transition(self.secret, &self.tls_server_end_point, input)? {

View File

@@ -1,12 +1,11 @@
//! Definitions for SCRAM messages.
use std::fmt;
use std::ops::Range;
use super::base64_decode_array;
use super::key::{ScramKey, SCRAM_KEY_LEN};
use super::signature::SignatureBuilder;
use crate::sasl::ChannelBinding;
use std::fmt;
use std::ops::Range;
/// Faithfully taken from PostgreSQL.
pub(crate) const SCRAM_RAW_NONCE_LEN: usize = 18;

View File

@@ -16,9 +16,10 @@ mod signature;
pub mod threadpool;
pub(crate) use exchange::{exchange, Exchange};
use hmac::{Hmac, Mac};
pub(crate) use key::ScramKey;
pub(crate) use secret::ServerSecret;
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
@@ -58,11 +59,13 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
#[cfg(test)]
mod tests {
use super::threadpool::ThreadPool;
use super::{Exchange, ServerSecret};
use crate::intern::EndpointIdInt;
use crate::sasl::{Mechanism, Step};
use crate::EndpointId;
use crate::{
intern::EndpointIdInt,
sasl::{Mechanism, Step},
EndpointId,
};
use super::{threadpool::ThreadPool, Exchange, ServerSecret};
#[test]
fn snapshot() {

View File

@@ -1,6 +1,7 @@
use hmac::digest::consts::U32;
use hmac::digest::generic_array::GenericArray;
use hmac::{Hmac, Mac};
use hmac::{
digest::{consts::U32, generic_array::GenericArray},
Hmac, Mac,
};
use sha2::Sha256;
pub(crate) struct Pbkdf2 {
@@ -65,11 +66,10 @@ impl Pbkdf2 {
#[cfg(test)]
mod tests {
use super::Pbkdf2;
use pbkdf2::pbkdf2_hmac_array;
use sha2::Sha256;
use super::Pbkdf2;
#[test]
fn works() {
let salt = b"sodium chloride";

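The test above checks the chunked `Pbkdf2` state machine against the one-shot reference from the `pbkdf2` crate; the reference call is simply (sketch reusing the test's salt):

use pbkdf2::pbkdf2_hmac_array;
use sha2::Sha256;

fn main() {
    // 4096 iterations of HMAC-SHA-256, producing a 32-byte key in one shot.
    let expected = pbkdf2_hmac_array::<Sha256, 32>(b"password", b"sodium chloride", 4096);
    assert_eq!(expected.len(), 32);
}
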
View File

@@ -4,21 +4,28 @@
//! 1. Fairness per endpoint.
//! 2. Yield support for high iteration counts.
use std::cell::RefCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};
use std::{
cell::RefCell,
future::Future,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Weak,
},
task::{Context, Poll},
};
use futures::FutureExt;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use rand::Rng;
use rand::{rngs::SmallRng, SeedableRng};
use crate::{
intern::EndpointIdInt,
metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
scram::countmin::CountMinSketch,
};
use super::pbkdf2::Pbkdf2;
use crate::intern::EndpointIdInt;
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
use crate::scram::countmin::CountMinSketch;
pub struct ThreadPool {
runtime: Option<tokio::runtime::Runtime>,
@@ -188,9 +195,10 @@ impl Drop for JobHandle {
#[cfg(test)]
mod tests {
use super::*;
use crate::EndpointId;
use super::*;
#[tokio::test]
async fn hash_is_correct() {
let pool = ThreadPool::new(1);

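The module doc above lists the pool's two goals, fairness per endpoint and yield support for high iteration counts; the yielding half boils down to running CPU-heavy work in bounded batches (a sketch — the batch size and job type are illustrative stand-ins, not the proxy's actual `Pbkdf2` job):

use std::task::Poll;

// Stand-in for a resumable job: each turn() runs a bounded batch of work.
struct ChunkedJob {
    remaining: u32,
}

impl ChunkedJob {
    fn turn(&mut self) -> Poll<u32> {
        let batch = self.remaining.min(4096);
        // ... run `batch` iterations of the real computation here ...
        self.remaining -= batch;
        if self.remaining == 0 {
            Poll::Ready(0) // done; a real job would return its output
        } else {
            Poll::Pending
        }
    }
}

async fn run_fair(mut job: ChunkedJob) -> u32 {
    loop {
        match job.turn() {
            Poll::Ready(out) => return out,
            // Hand the worker back between batches so one endpoint's
            // expensive hash cannot monopolize the pool.
            Poll::Pending => tokio::task::yield_now().await,
        }
    }
}
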
View File

@@ -1,44 +1,47 @@
use std::io;
use std::sync::Arc;
use std::time::Duration;
use std::{io, sync::Arc, time::Duration};
use async_trait::async_trait;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use p256::ecdsa::SigningKey;
use p256::elliptic_curve::JwkEcKey;
use p256::{ecdsa::SigningKey, elliptic_curve::JwkEcKey};
use rand::rngs::OsRng;
use tokio::net::{lookup_host, TcpStream};
use tracing::field::display;
use tracing::{debug, info};
use tracing::{debug, field::display, info};
use super::conn_pool::poll_client;
use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool};
use super::http_conn_pool::{self, poll_http2_client, Send};
use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, check_peer_addr_is_in_list, AuthError};
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
use crate::{
auth::{
self,
backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
check_peer_addr_is_in_list, AuthError,
},
compute,
config::ProxyConfig,
context::RequestMonitoring,
control_plane::{
errors::{GetAuthInfoError, WakeComputeError},
locks::ApiLocks,
provider::ApiLockError,
CachedNodeInfo,
},
error::{ErrorKind, ReportableError, UserFacingError},
intern::EndpointIdInt,
proxy::{
connect_compute::ConnectMechanism,
retry::{CouldRetry, ShouldRetryWakeCompute},
},
rate_limiter::EndpointRateLimiter,
EndpointId, Host,
};
use super::{
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
http_conn_pool::{self, poll_http2_client},
local_conn_pool::{self, LocalClient, LocalConnPool},
};
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::provider::ApiLockError;
use crate::control_plane::CachedNodeInfo;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::{compute, EndpointId, Host};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool<Send>>,
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -204,7 +207,7 @@ impl PoolingBackend {
&self,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
) -> Result<http_conn_pool::Client, HttpConnError> {
info!("pool: looking for an existing connection");
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
@@ -254,50 +257,16 @@ impl PoolingBackend {
return Ok(client);
}
let local_backend = match &self.auth_backend {
auth::Backend::ControlPlane(_, ()) => {
unreachable!("only local_proxy can connect to local postgres")
}
auth::Backend::Local(local) => local,
};
#[allow(unreachable_code, clippy::todo)]
if !self.local_pool.initialized(&conn_info) {
// only install and grant usage one at a time.
let _permit = local_backend.initialize.acquire().await.unwrap();
// check again for race
if !self.local_pool.initialized(&conn_info) {
local_backend
.compute_ctl
.install_extension(&ExtensionInstallRequest {
extension: EXT_NAME,
database: conn_info.dbname.clone(),
// todo: move to const or config
version: EXT_VERSION,
})
.await?;
local_backend
.compute_ctl
.grant_role(&SetRoleGrantsRequest {
// fixed for pg_session_jwt
schema: EXT_SCHEMA,
privileges: vec![Privilege::Usage],
database: conn_info.dbname.clone(),
role: conn_info.user_info.user.clone(),
})
.await?;
self.local_pool.set_initialized(&conn_info);
}
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
let mut node_info = local_backend.node_info.clone();
let mut node_info = match &self.auth_backend {
auth::Backend::ControlPlane(_, ()) => {
unreachable!("only local_proxy can connect to local postgres")
}
auth::Backend::Local(local) => local.node_info.clone(),
};
let (key, jwk) = create_random_jwk();
@@ -362,8 +331,6 @@ pub(crate) enum HttpConnError {
#[error("could not parse JWT payload")]
JwtPayloadError(serde_json::Error),
#[error("could not install extension: {0}")]
ComputeCtl(#[from] ComputeCtlError),
#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
#[error("user not authenticated")]
@@ -388,7 +355,6 @@ impl ReportableError for HttpConnError {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::ComputeCtl(_) => ErrorKind::Service,
HttpConnError::JwtPayloadError(_) => ErrorKind::User,
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
HttpConnError::AuthError(a) => a.get_error_kind(),
@@ -404,7 +370,6 @@ impl UserFacingError for HttpConnError {
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
HttpConnError::ComputeCtl(_) => "could not set up the JWT authorization database extension".to_string(),
HttpConnError::JwtPayloadError(p) => p.to_string(),
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
HttpConnError::AuthError(c) => c.to_string_client(),
@@ -421,7 +386,6 @@ impl CouldRetry for HttpConnError {
match self {
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ComputeCtl(_) => false,
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::JwtPayloadError(_) => false,
HttpConnError::GetAuthInfo(_) => false,
@@ -525,7 +489,7 @@ impl ConnectMechanism for TokioMechanism {
}
struct HyperMechanism {
pool: Arc<http_conn_pool::GlobalConnPool<Send>>,
pool: Arc<http_conn_pool::GlobalConnPool>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
@@ -535,7 +499,7 @@ struct HyperMechanism {
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client<Send>;
type Connection = http_conn_pool::Client;
type ConnectError = HttpConnError;
type Error = HttpConnError;

View File

@@ -1,8 +1,10 @@
//! A set for cancelling random http connections
use std::hash::{BuildHasher, BuildHasherDefault};
use std::num::NonZeroUsize;
use std::time::Duration;
use std::{
hash::{BuildHasher, BuildHasherDefault},
num::NonZeroUsize,
time::Duration,
};
use indexmap::IndexMap;
use parking_lot::Mutex;
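Given the module docs and imports above, the usual way to support O(1) removal of a pseudo-random victim is an IndexMap guarded by a mutex, removing by index with swap_remove_index. A hedged sketch of that idea follows; the key and token types are assumptions for illustration, not this file's actual definitions:

use indexmap::IndexMap;
use parking_lot::Mutex;
use tokio_util::sync::CancellationToken;

struct CancelShard {
    map: Mutex<IndexMap<uuid::Uuid, CancellationToken>>,
}

impl CancelShard {
    // Cancel and remove one entry chosen by the caller (e.g. a random index).
    fn cancel_one(&self, idx: usize) -> bool {
        let mut map = self.map.lock();
        if map.is_empty() {
            return false;
        }
        let idx = idx % map.len();
        match map.swap_remove_index(idx) {
            Some((_, token)) => {
                token.cancel();
                true
            }
            None => false,
        }
    }
}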

View File

@@ -1,42 +1,68 @@
use std::fmt;
use std::pin::pin;
use std::sync::{Arc, Weak};
use std::task::{ready, Poll};
use futures::future::poll_fn;
use futures::Future;
use dashmap::DashMap;
use futures::{future::poll_fn, Future};
use parking_lot::RwLock;
use rand::Rng;
use smallvec::SmallVec;
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
use std::{
fmt,
task::{ready, Poll},
};
use std::{
ops::Deref,
sync::atomic::{self, AtomicUsize},
};
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::{AsyncMessage, Socket};
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
use crate::context::RequestMonitoring;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::Metrics;
use super::conn_pool_lib::{Client, ClientInnerExt, ConnInfo, GlobalConnPool};
#[cfg(test)]
use {
super::conn_pool_lib::GlobalConnPoolOptions,
crate::auth::backend::ComputeUserInfo,
std::{sync::atomic, time::Duration},
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{
auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName,
};
use tracing::{debug, error, warn, Span};
use tracing::{info, info_span, Instrument};
use super::backend::HttpConnError;
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {
pub(crate) conn_info: ConnInfo,
pub(crate) auth: AuthData,
}
#[derive(Debug, Clone)]
pub(crate) struct ConnInfo {
pub(crate) user_info: ComputeUserInfo,
pub(crate) dbname: DbName,
}
#[derive(Debug, Clone)]
pub(crate) enum AuthData {
Password(SmallVec<[u8; 16]>),
Jwt(String),
}
impl ConnInfo {
// hm, change to hasher to avoid cloning?
pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
(self.dbname.clone(), self.user_info.user.clone())
}
pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
// We don't want to cache http connections for ephemeral endpoints.
if self.user_info.options.is_ephemeral() {
None
} else {
Some(self.user_info.endpoint_cache_key())
}
}
}
impl fmt::Display for ConnInfo {
// use custom display to avoid logging password
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -51,6 +77,402 @@ impl fmt::Display for ConnInfo {
}
}
struct ConnPoolEntry<C: ClientInnerExt> {
conn: ClientInner<C>,
_last_access: std::time::Instant,
}
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
total_conns: usize,
max_conns: usize,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
global_pool_size_max_conns: usize,
}
impl<C: ClientInnerExt> EndpointConnPool<C> {
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
let Self {
pools,
total_conns,
global_connections_count,
..
} = self;
pools.get_mut(&db_user).and_then(|pool_entries| {
pool_entries.get_conn_entry(total_conns, global_connections_count.clone())
})
}
fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
let Self {
pools,
total_conns,
global_connections_count,
..
} = self;
if let Some(pool) = pools.get_mut(&db_user) {
let old_len = pool.conns.len();
pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
let new_len = pool.conns.len();
let removed = old_len - new_len;
if removed > 0 {
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
}
*total_conns -= removed;
removed > 0
} else {
false
}
}
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
let conn_id = client.conn_id;
if client.is_closed() {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
return;
}
let global_max_conn = pool.read().global_pool_size_max_conns;
if pool
.read()
.global_connections_count
.load(atomic::Ordering::Relaxed)
>= global_max_conn
{
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
return;
}
// return connection to the pool
let mut returned = false;
let mut per_db_size = 0;
let total_conns = {
let mut pool = pool.write();
if pool.total_conns < pool.max_conns {
let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
pool_entries.conns.push(ConnPoolEntry {
conn: client,
_last_access: std::time::Instant::now(),
});
returned = true;
per_db_size = pool_entries.conns.len();
pool.total_conns += 1;
pool.global_connections_count
.fetch_add(1, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.inc();
}
pool.total_conns
};
// do logging outside of the mutex
if returned {
info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
} else {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
}
}
}
impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
fn drop(&mut self) {
if self.total_conns > 0 {
self.global_connections_count
.fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(self.total_conns as i64);
}
}
}
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
conns: Vec<ConnPoolEntry<C>>,
}
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
fn default() -> Self {
Self { conns: Vec::new() }
}
}
impl<C: ClientInnerExt> DbUserConnPool<C> {
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
let old_len = self.conns.len();
self.conns.retain(|conn| !conn.conn.is_closed());
let new_len = self.conns.len();
let removed = old_len - new_len;
*conns -= removed;
removed
}
fn get_conn_entry(
&mut self,
conns: &mut usize,
global_connections_count: Arc<AtomicUsize>,
) -> Option<ConnPoolEntry<C>> {
let mut removed = self.clear_closed_clients(conns);
let conn = self.conns.pop();
if conn.is_some() {
*conns -= 1;
removed += 1;
}
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
conn
}
}
pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
// endpoint -> per-endpoint connection pool
//
// This map is expected to be fairly contended, so return a reference to the
// per-endpoint pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
/// Number of endpoint-connection pools
///
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
/// It's only used for diagnostics.
global_pool_size: AtomicUsize,
/// Total number of connections in the pool
global_connections_count: Arc<AtomicUsize>,
config: &'static crate::config::HttpConfig,
}
#[derive(Debug, Clone, Copy)]
pub struct GlobalConnPoolOptions {
// Maximum number of connections per one endpoint.
// Can mix different (dbname, username) connections.
// When running out of free slots for a particular endpoint,
// falls back to opening a new connection for each request.
pub max_conns_per_endpoint: usize,
pub gc_epoch: Duration,
pub pool_shards: usize,
pub idle_timeout: Duration,
pub opt_in: bool,
// Total number of connections in the pool.
pub max_total_conns: usize,
}
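These options come straight from proxy configuration; as a usage sketch, a pool could be configured like this (the numbers are illustrative, not the shipped defaults):

use std::time::Duration;

let pool_options = GlobalConnPoolOptions {
    max_conns_per_endpoint: 20,
    gc_epoch: Duration::from_secs(60),
    pool_shards: 2,
    idle_timeout: Duration::from_secs(300),
    opt_in: true,
    max_total_conns: 3000,
};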
impl<C: ClientInnerExt> GlobalConnPool<C> {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
global_pool: DashMap::with_shard_amount(shards),
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
})
}
#[cfg(test)]
pub(crate) fn get_global_connections_count(&self) -> usize {
self.global_connections_count
.load(atomic::Ordering::Relaxed)
}
pub(crate) fn get_idle_timeout(&self) -> Duration {
self.config.pool_options.idle_timeout
}
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool {
pools, total_conns, ..
} = pool.get_mut();
// ensure that closed clients are removed
for db_pool in pools.values_mut() {
clients_removed += db_pool.clear_closed_clients(total_conns);
}
// we only remove this pool if it has no active connections
if *total_conns == 0 {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
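The retain closure above leans on Arc::get_mut, which returns Some only when no other strong or weak references exist; that is exactly the signal that an endpoint pool has no live connections. A self-contained demonstration:

use std::sync::Arc;

fn main() {
    let mut pool = Arc::new(5);
    // Unique: gc may inspect it and possibly drop it.
    assert!(Arc::get_mut(&mut pool).is_some());

    let in_use = Arc::clone(&pool);
    // Shared: a connection still holds a reference, so gc must keep it.
    assert!(Arc::get_mut(&mut pool).is_none());

    drop(in_use);
    assert!(Arc::get_mut(&mut pool).is_some());
}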
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Result<Option<Client<C>>, HttpConnError> {
let mut client: Option<ClientInner<C>> = None;
let Some(endpoint) = conn_info.endpoint_cache_key() else {
return Ok(None);
};
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
if let Some(entry) = endpoint_pool
.write()
.get_conn_entry(conn_info.db_and_user())
{
client = Some(entry.conn);
}
let endpoint_pool = Arc::downgrade(&endpoint_pool);
// Return the cached connection if one was found; otherwise fall through so the caller establishes a new one.
if let Some(client) = client {
if client.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
}
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
client.session.send(ctx.session_id())?;
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
}
Ok(None)
}
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool<C>>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
pools: HashMap::new(),
total_conns: 0,
max_conns: self.config.pool_options.max_conns_per_endpoint,
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
global_pool_size_max_conns: self.config.pool_options.max_total_conns,
}));
// find or create a pool for this endpoint
let mut created = false;
let pool = self
.global_pool
.entry(endpoint.clone())
.or_insert_with(|| {
created = true;
new_pool
})
.clone();
// log new global pool size
if created {
let global_pool_size = self
.global_pool_size
.fetch_add(1, atomic::Ordering::Relaxed)
+ 1;
info!(
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
);
}
pool
}
}
pub(crate) fn poll_client<C: ClientInnerExt>(
global_pool: Arc<GlobalConnPool<C>>,
ctx: &RequestMonitoring,
@@ -154,7 +576,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}
.instrument(span));
let inner = ClientInnerRemote {
let inner = ClientInner {
inner: client,
session: tx,
cancel,
@@ -164,7 +586,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
Client::new(inner, conn_info, pool_clone)
}
pub(crate) struct ClientInnerRemote<C: ClientInnerExt> {
struct ClientInner<C: ClientInnerExt> {
inner: C,
session: tokio::sync::watch::Sender<uuid::Uuid>,
cancel: CancellationToken,
@@ -172,48 +594,143 @@ pub(crate) struct ClientInnerRemote<C: ClientInnerExt> {
conn_id: uuid::Uuid,
}
impl<C: ClientInnerExt> ClientInnerRemote<C> {
pub(crate) fn inner_mut(&mut self) -> &mut C {
&mut self.inner
}
pub(crate) fn inner(&self) -> &C {
&self.inner
}
pub(crate) fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
&mut self.session
}
pub(crate) fn aux(&self) -> &MetricsAuxInfo {
&self.aux
}
pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
self.conn_id
}
pub(crate) fn is_closed(&self) -> bool {
self.inner.is_closed()
}
}
impl<C: ClientInnerExt> Drop for ClientInnerRemote<C> {
impl<C: ClientInnerExt> Drop for ClientInner<C> {
fn drop(&mut self) {
// on client drop, tell the conn to shut down
self.cancel.cancel();
}
}
pub(crate) trait ClientInnerExt: Sync + Send + 'static {
fn is_closed(&self) -> bool;
fn get_process_id(&self) -> i32;
}
impl ClientInnerExt for tokio_postgres::Client {
fn is_closed(&self) -> bool {
self.is_closed()
}
fn get_process_id(&self) -> i32 {
self.get_process_id()
}
}
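The tests further down satisfy this trait with a mock; spelled out in full, such an implementation needs only the two methods. The closed-flag semantics here mirror what the truncated MockClient in the tests appears to do:

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

struct MockClient(Arc<AtomicBool>);

impl ClientInnerExt for MockClient {
    fn is_closed(&self) -> bool {
        self.0.load(Ordering::Relaxed)
    }
    fn get_process_id(&self) -> i32 {
        // A mock has no backend process; a sentinel is enough for tests.
        -1
    }
}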
impl<C: ClientInnerExt> ClientInner<C> {
pub(crate) fn is_closed(&self) -> bool {
self.inner.is_closed()
}
}
impl<C: ClientInnerExt> Client<C> {
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux;
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
})
}
}
pub(crate) struct Client<C: ClientInnerExt> {
span: Span,
inner: Option<ClientInner<C>>,
conn_info: ConnInfo,
pool: Weak<RwLock<EndpointConnPool<C>>>,
}
pub(crate) struct Discard<'a, C: ClientInnerExt> {
conn_info: &'a ConnInfo,
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}
impl<C: ClientInnerExt> Client<C> {
pub(self) fn new(
inner: ClientInner<C>,
conn_info: ConnInfo,
pool: Weak<RwLock<EndpointConnPool<C>>>,
) -> Self {
Self {
inner: Some(inner),
span: Span::current(),
conn_info,
pool,
}
}
pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
let Self {
inner,
pool,
conn_info,
span: _,
} = self;
let inner = inner.as_mut().expect("client inner should not be removed");
(&mut inner.inner, Discard { conn_info, pool })
}
}
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}
impl<C: ClientInnerExt> Deref for Client<C> {
type Target = C;
fn deref(&self) -> &Self::Target {
&self
.inner
.as_ref()
.expect("client inner should not be removed")
.inner
}
}
impl<C: ClientInnerExt> Client<C> {
fn do_drop(&mut self) -> Option<impl FnOnce()> {
let conn_info = self.conn_info.clone();
let client = self
.inner
.take()
.expect("client inner should not be removed");
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
let current_span = self.span.clone();
// return connection to the pool
return Some(move || {
let _span = current_span.enter();
EndpointConnPool::put(&conn_pool, &conn_info, client);
});
}
None
}
}
impl<C: ClientInnerExt> Drop for Client<C> {
fn drop(&mut self) {
if let Some(drop) = self.do_drop() {
tokio::task::spawn_blocking(drop);
}
}
}
#[cfg(test)]
mod tests {
use std::mem;
use std::sync::atomic::AtomicBool;
use std::{mem, sync::atomic::AtomicBool};
use crate::{
proxy::NeonOptions, serverless::cancel_set::CancelSet, BranchId, EndpointId, ProjectId,
};
use super::*;
use crate::proxy::NeonOptions;
use crate::serverless::cancel_set::CancelSet;
use crate::{BranchId, EndpointId, ProjectId};
struct MockClient(Arc<AtomicBool>);
impl MockClient {
@@ -230,12 +747,12 @@ mod tests {
}
}
fn create_inner() -> ClientInnerRemote<MockClient> {
fn create_inner() -> ClientInner<MockClient> {
create_inner_with(MockClient::new(false))
}
fn create_inner_with(client: MockClient) -> ClientInnerRemote<MockClient> {
ClientInnerRemote {
fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
ClientInner {
inner: client,
session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
cancel: CancellationToken::new(),
@@ -282,7 +799,7 @@ mod tests {
{
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
assert_eq!(0, pool.get_global_connections_count());
client.inner_mut().1.discard();
client.inner().1.discard();
// Discard should not return the connection to the pool.
assert_eq!(0, pool.get_global_connections_count());
}

View File

@@ -1,562 +0,0 @@
use dashmap::DashMap;
use parking_lot::RwLock;
use rand::Rng;
use std::{collections::HashMap, sync::Arc, sync::Weak, time::Duration};
use std::{
ops::Deref,
sync::atomic::{self, AtomicUsize},
};
use tokio_postgres::ReadyForQueryStatus;
use crate::control_plane::messages::ColdStartInfo;
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{
auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName,
};
use super::conn_pool::ClientInnerRemote;
use tracing::info;
use tracing::{debug, Span};
use super::backend::HttpConnError;
#[derive(Debug, Clone)]
pub(crate) struct ConnInfo {
pub(crate) user_info: ComputeUserInfo,
pub(crate) dbname: DbName,
}
impl ConnInfo {
// hm, change to hasher to avoid cloning?
pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
(self.dbname.clone(), self.user_info.user.clone())
}
pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
// We don't want to cache http connections for ephemeral endpoints.
if self.user_info.options.is_ephemeral() {
None
} else {
Some(self.user_info.endpoint_cache_key())
}
}
}
pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
pub(crate) conn: ClientInnerRemote<C>,
pub(crate) _last_access: std::time::Instant,
}
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
total_conns: usize,
max_conns: usize,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
global_pool_size_max_conns: usize,
}
impl<C: ClientInnerExt> EndpointConnPool<C> {
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
let Self {
pools,
total_conns,
global_connections_count,
..
} = self;
pools.get_mut(&db_user).and_then(|pool_entries| {
let (entry, removed) = pool_entries.get_conn_entry(total_conns);
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
entry
})
}
pub(crate) fn remove_client(
&mut self,
db_user: (DbName, RoleName),
conn_id: uuid::Uuid,
) -> bool {
let Self {
pools,
total_conns,
global_connections_count,
..
} = self;
if let Some(pool) = pools.get_mut(&db_user) {
let old_len = pool.conns.len();
pool.conns.retain(|conn| conn.conn.get_conn_id() != conn_id);
let new_len = pool.conns.len();
let removed = old_len - new_len;
if removed > 0 {
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
}
*total_conns -= removed;
removed > 0
} else {
false
}
}
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerRemote<C>) {
let conn_id = client.get_conn_id();
if client.is_closed() {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
return;
}
let global_max_conn = pool.read().global_pool_size_max_conns;
if pool
.read()
.global_connections_count
.load(atomic::Ordering::Relaxed)
>= global_max_conn
{
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
return;
}
// return connection to the pool
let mut returned = false;
let mut per_db_size = 0;
let total_conns = {
let mut pool = pool.write();
if pool.total_conns < pool.max_conns {
let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
pool_entries.conns.push(ConnPoolEntry {
conn: client,
_last_access: std::time::Instant::now(),
});
returned = true;
per_db_size = pool_entries.conns.len();
pool.total_conns += 1;
pool.global_connections_count
.fetch_add(1, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.inc();
}
pool.total_conns
};
// do logging outside of the mutex
if returned {
info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
} else {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
}
}
}
impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
fn drop(&mut self) {
if self.total_conns > 0 {
self.global_connections_count
.fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(self.total_conns as i64);
}
}
}
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
pub(crate) conns: Vec<ConnPoolEntry<C>>,
}
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
fn default() -> Self {
Self { conns: Vec::new() }
}
}
impl<C: ClientInnerExt> DbUserConnPool<C> {
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
let old_len = self.conns.len();
self.conns.retain(|conn| !conn.conn.is_closed());
let new_len = self.conns.len();
let removed = old_len - new_len;
*conns -= removed;
removed
}
pub(crate) fn get_conn_entry(
&mut self,
conns: &mut usize,
) -> (Option<ConnPoolEntry<C>>, usize) {
let mut removed = self.clear_closed_clients(conns);
let conn = self.conns.pop();
if conn.is_some() {
*conns -= 1;
removed += 1;
}
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
(conn, removed)
}
}
pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
// endpoint -> per-endpoint connection pool
//
// This map is expected to be fairly contended, so return a reference to the
// per-endpoint pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
/// Number of endpoint-connection pools
///
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
/// It's only used for diagnostics.
global_pool_size: AtomicUsize,
/// Total number of connections in the pool
global_connections_count: Arc<AtomicUsize>,
config: &'static crate::config::HttpConfig,
}
#[derive(Debug, Clone, Copy)]
pub struct GlobalConnPoolOptions {
// Maximum number of connections per one endpoint.
// Can mix different (dbname, username) connections.
// When running out of free slots for a particular endpoint,
// falls back to opening a new connection for each request.
pub max_conns_per_endpoint: usize,
pub gc_epoch: Duration,
pub pool_shards: usize,
pub idle_timeout: Duration,
pub opt_in: bool,
// Total number of connections in the pool.
pub max_total_conns: usize,
}
impl<C: ClientInnerExt> GlobalConnPool<C> {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
global_pool: DashMap::with_shard_amount(shards),
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
})
}
#[cfg(test)]
pub(crate) fn get_global_connections_count(&self) -> usize {
self.global_connections_count
.load(atomic::Ordering::Relaxed)
}
pub(crate) fn get_idle_timeout(&self) -> Duration {
self.config.pool_options.idle_timeout
}
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
pub(crate) fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool {
pools, total_conns, ..
} = pool.get_mut();
// ensure that closed clients are removed
for db_pool in pools.values_mut() {
clients_removed += db_pool.clear_closed_clients(total_conns);
}
// we only remove this pool if it has no active connections
if *total_conns == 0 {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
pub(crate) fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool<C>>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
pools: HashMap::new(),
total_conns: 0,
max_conns: self.config.pool_options.max_conns_per_endpoint,
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
global_pool_size_max_conns: self.config.pool_options.max_total_conns,
}));
// find or create a pool for this endpoint
let mut created = false;
let pool = self
.global_pool
.entry(endpoint.clone())
.or_insert_with(|| {
created = true;
new_pool
})
.clone();
// log new global pool size
if created {
let global_pool_size = self
.global_pool_size
.fetch_add(1, atomic::Ordering::Relaxed)
+ 1;
info!(
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
);
}
pool
}
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Result<Option<Client<C>>, HttpConnError> {
let mut client: Option<ClientInnerRemote<C>> = None;
let Some(endpoint) = conn_info.endpoint_cache_key() else {
return Ok(None);
};
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
if let Some(entry) = endpoint_pool
.write()
.get_conn_entry(conn_info.db_and_user())
{
client = Some(entry.conn);
}
let endpoint_pool = Arc::downgrade(&endpoint_pool);
// Return the cached connection if one was found; otherwise fall through so the caller establishes a new one.
if let Some(mut client) = client {
if client.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
}
tracing::Span::current()
.record("conn_id", tracing::field::display(client.get_conn_id()));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner().get_process_id()),
);
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
client.session().send(ctx.session_id())?;
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
}
Ok(None)
}
}
impl<C: ClientInnerExt> Client<C> {
pub(crate) fn new(
inner: ClientInnerRemote<C>,
conn_info: ConnInfo,
pool: Weak<RwLock<EndpointConnPool<C>>>,
) -> Self {
Self {
inner: Some(inner),
span: Span::current(),
conn_info,
pool,
}
}
pub(crate) fn inner_mut(&mut self) -> (&mut C, Discard<'_, C>) {
let Self {
inner,
pool,
conn_info,
span: _,
} = self;
let inner = inner.as_mut().expect("client inner should not be removed");
let inner_ref = inner.inner_mut();
(inner_ref, Discard { conn_info, pool })
}
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux();
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
})
}
pub(crate) fn do_drop(&mut self) -> Option<impl FnOnce()> {
let conn_info = self.conn_info.clone();
let client = self
.inner
.take()
.expect("client inner should not be removed");
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
let current_span = self.span.clone();
// return connection to the pool
return Some(move || {
let _span = current_span.enter();
EndpointConnPool::put(&conn_pool, &conn_info, client);
});
}
None
}
}
pub(crate) struct Client<C: ClientInnerExt> {
span: Span,
inner: Option<ClientInnerRemote<C>>,
conn_info: ConnInfo,
pool: Weak<RwLock<EndpointConnPool<C>>>,
}
impl<C: ClientInnerExt> Drop for Client<C> {
fn drop(&mut self) {
if let Some(drop) = self.do_drop() {
tokio::task::spawn_blocking(drop);
}
}
}
impl<C: ClientInnerExt> Deref for Client<C> {
type Target = C;
fn deref(&self) -> &Self::Target {
self.inner
.as_ref()
.expect("client inner should not be removed")
.inner()
}
}
pub(crate) trait ClientInnerExt: Sync + Send + 'static {
fn is_closed(&self) -> bool;
fn get_process_id(&self) -> i32;
}
impl ClientInnerExt for tokio_postgres::Client {
fn is_closed(&self) -> bool {
self.is_closed()
}
fn get_process_id(&self) -> i32 {
self.get_process_id()
}
}
pub(crate) struct Discard<'a, C: ClientInnerExt> {
conn_info: &'a ConnInfo,
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}

View File

@@ -1,37 +1,37 @@
use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
use dashmap::DashMap;
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::{sync::Arc, sync::Weak};
use tokio::net::TcpStream;
use tracing::{debug, error, info, info_span, Instrument};
use crate::context::RequestMonitoring;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, EndpointCacheKey};
use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
use crate::EndpointCacheKey;
use tracing::{debug, error};
use tracing::{info, info_span, Instrument};
use super::conn_pool::ConnInfo;
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect =
http2::Connection<TokioIo<TcpStream>, hyper::body::Incoming, TokioExecutor>;
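For orientation, a Send/Connect pair of exactly these aliased types comes out of hyper's HTTP/2 client handshake. A hedged sketch, assuming a plain TCP dial (the real dialing logic lives elsewhere in the proxy):

use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::net::TcpStream;

async fn open_channel(addr: &str) -> anyhow::Result<(Send, Connect)> {
    let stream = TcpStream::connect(addr).await?;
    let (send_request, connection): (Send, Connect) =
        http2::handshake(TokioExecutor::new(), TokioIo::new(stream)).await?;
    // The Connection future must be spawned and polled for requests on
    // send_request to make progress; callers typically tokio::spawn it.
    Ok((send_request, connection))
}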
#[derive(Clone)]
pub(crate) struct ConnPoolEntry<C: ClientInnerExt + Clone> {
conn: C,
struct ConnPoolEntry {
conn: Send,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
}
// Per-endpoint connection pool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt + Clone> {
pub(crate) struct EndpointConnPool {
// TODO(conrad):
// either we should open more connections depending on stream count
// (not exposed by hyper, need our own counter)
@@ -41,13 +41,13 @@ pub(crate) struct EndpointConnPool<C: ClientInnerExt + Clone> {
// seems somewhat redundant though.
//
// Probably we should run a semaphore and just the single conn. TBD.
conns: VecDeque<ConnPoolEntry<C>>,
conns: VecDeque<ConnPoolEntry>,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
}
impl<C: ClientInnerExt + Clone> EndpointConnPool<C> {
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry<C>> {
impl EndpointConnPool {
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry> {
let Self { conns, .. } = self;
loop {
@@ -82,7 +82,7 @@ impl<C: ClientInnerExt + Clone> EndpointConnPool<C> {
}
}
impl<C: ClientInnerExt + Clone> Drop for EndpointConnPool<C> {
impl Drop for EndpointConnPool {
fn drop(&mut self) {
if !self.conns.is_empty() {
self.global_connections_count
@@ -96,12 +96,12 @@ impl<C: ClientInnerExt + Clone> Drop for EndpointConnPool<C> {
}
}
pub(crate) struct GlobalConnPool<C: ClientInnerExt + Clone> {
pub(crate) struct GlobalConnPool {
// endpoint -> per-endpoint connection pool
//
// This map is expected to be fairly contended, so return a reference to the
// per-endpoint pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
/// Number of endpoint-connection pools
///
@@ -116,7 +116,7 @@ pub(crate) struct GlobalConnPool<C: ClientInnerExt + Clone> {
config: &'static crate::config::HttpConfig,
}
impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
impl GlobalConnPool {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
@@ -211,7 +211,7 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Option<Client<C>> {
) -> Option<Client> {
let endpoint = conn_info.endpoint_cache_key()?;
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
let client = endpoint_pool.write().get_conn_entry()?;
@@ -229,7 +229,7 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool<C>>> {
) -> Arc<RwLock<EndpointConnPool>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
@@ -269,14 +269,14 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool<Send>>,
global_pool: Arc<GlobalConnPool>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
client: Send,
connection: Connect,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<Send> {
) -> Client {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
@@ -323,13 +323,13 @@ pub(crate) fn poll_http2_client(
Client::new(client, aux)
}
pub(crate) struct Client<C: ClientInnerExt + Clone> {
pub(crate) inner: C,
pub(crate) struct Client {
pub(crate) inner: Send,
aux: MetricsAuxInfo,
}
impl<C: ClientInnerExt + Clone> Client<C> {
pub(self) fn new(inner: C, aux: MetricsAuxInfo) -> Self {
impl Client {
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
Self { inner, aux }
}
@@ -340,14 +340,3 @@ impl<C: ClientInnerExt + Clone> Client<C> {
})
}
}
impl ClientInnerExt for Send {
fn is_closed(&self) -> bool {
self.is_closed()
}
fn get_process_id(&self) -> i32 {
// ideally this would report something meaningful
-1
}
}

View File

@@ -1,11 +1,12 @@
//! Things stolen from `libs/utils/src/http` to add hyper 1.0 compatibility
//! Will merge back in at some point in the future.
use anyhow::Context;
use bytes::Bytes;
use anyhow::Context;
use http::{Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use serde::Serialize;
use utils::http::error::ApiError;

View File

@@ -1,5 +1,7 @@
use serde_json::{Map, Value};
use tokio_postgres::types::{Kind, Type};
use serde_json::Map;
use serde_json::Value;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use tokio_postgres::Row;
//
@@ -254,9 +256,8 @@ fn _pg_array_parse(
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
use serde_json::json;
#[test]
fn test_atomic_types_to_pg_params() {

View File

@@ -1,53 +1,37 @@
//! Manages the pool of connections between local_proxy and postgres.
//!
//! The pool is keyed by database and role_name, and can contain multiple connections
//! shared between users.
//!
//! The pool manages the pg_session_jwt extension used for authorizing
//! requests in the db.
//!
//! The first time a db/role pair is seen, local_proxy attempts to install the extension
//! and grant usage to the role on the given schema.
use std::collections::HashMap;
use std::pin::pin;
use std::sync::{Arc, Weak};
use std::task::{ready, Poll};
use std::time::Duration;
use futures::future::poll_fn;
use futures::Future;
use futures::{future::poll_fn, Future};
use indexmap::IndexMap;
use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
use p256::ecdsa::{Signature, SigningKey};
use parking_lot::RwLock;
use serde_json::value::RawValue;
use signature::Signer;
use std::task::{ready, Poll};
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::types::ToSql;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument, Span};
use super::backend::HttpConnError;
use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
use crate::context::RequestMonitoring;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::Metrics;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, DbName, RoleName};
use crate::{DbName, RoleName};
use tracing::{error, warn, Span};
use tracing::{info, info_span, Instrument};
pub(crate) const EXT_NAME: &str = "pg_session_jwt";
pub(crate) const EXT_VERSION: &str = "0.1.1";
pub(crate) const EXT_SCHEMA: &str = "auth";
use super::backend::HttpConnError;
use super::conn_pool::{ClientInnerExt, ConnInfo};
struct ConnPoolEntry<C: ClientInnerExt> {
conn: ClientInner<C>,
_last_access: std::time::Instant,
}
// /// key id for the pg_session_jwt state
// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1);
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
@@ -153,18 +137,11 @@ impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
conns: Vec<ConnPoolEntry<C>>,
// true if we have definitely installed the extension and
// granted the role access to the auth schema.
initialized: bool,
}
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
fn default() -> Self {
Self {
conns: Vec::new(),
initialized: false,
}
Self { conns: Vec::new() }
}
}
@@ -219,16 +196,25 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
self.config.pool_options.idle_timeout
}
// pub(crate) fn shutdown(&self) {
// let mut pool = self.global_pool.write();
// pool.pools.clear();
// pool.total_conns = 0;
// }
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Result<Option<LocalClient<C>>, HttpConnError> {
let client = self
let mut client: Option<ClientInner<C>> = None;
if let Some(entry) = self
.global_pool
.write()
.get_conn_entry(conn_info.db_and_user())
.map(|entry| entry.conn);
{
client = Some(entry.conn);
}
// Return the cached connection if one was found; otherwise fall through so the caller establishes a new one.
if let Some(client) = client {
@@ -256,23 +242,6 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
}
Ok(None)
}
pub(crate) fn initialized(self: &Arc<Self>, conn_info: &ConnInfo) -> bool {
self.global_pool
.read()
.pools
.get(&conn_info.db_and_user())
.map_or(false, |pool| pool.initialized)
}
pub(crate) fn set_initialized(self: &Arc<Self>, conn_info: &ConnInfo) {
self.global_pool
.write()
.pools
.entry(conn_info.db_and_user())
.or_default()
.initialized = true;
}
}
#[allow(clippy::too_many_arguments)]
@@ -390,7 +359,7 @@ pub(crate) fn poll_client(
LocalClient::new(inner, conn_info, pool_clone)
}
pub(crate) struct ClientInner<C: ClientInnerExt> {
struct ClientInner<C: ClientInnerExt> {
inner: C,
session: tokio::sync::watch::Sender<uuid::Uuid>,
cancel: CancellationToken,
@@ -415,24 +384,13 @@ impl<C: ClientInnerExt> ClientInner<C> {
}
}
impl ClientInner<tokio_postgres::Client> {
pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
self.jti += 1;
let token = resign_jwt(&self.key, payload, self.jti)?;
// initiates the auth session
self.inner.simple_query("discard all").await?;
self.inner
.query(
"select auth.jwt_session_init($1)",
&[&token as &(dyn ToSql + Sync)],
)
.await?;
let pid = self.inner.get_process_id();
info!(pid, jti = self.jti, "user session state init");
Ok(())
impl<C: ClientInnerExt> LocalClient<C> {
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux;
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
})
}
}
@@ -461,18 +419,6 @@ impl<C: ClientInnerExt> LocalClient<C> {
pool,
}
}
pub(crate) fn client_inner(&mut self) -> (&mut ClientInner<C>, Discard<'_, C>) {
let Self {
inner,
pool,
conn_info,
span: _,
} = self;
let inner_m = inner.as_mut().expect("client inner should not be removed");
(inner_m, Discard { conn_info, pool })
}
pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
let Self {
inner,
@@ -485,6 +431,33 @@ impl<C: ClientInnerExt> LocalClient<C> {
}
}
impl LocalClient<tokio_postgres::Client> {
pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
let inner = self
.inner
.as_mut()
.expect("client inner should not be removed");
inner.jti += 1;
let token = resign_jwt(&inner.key, payload, inner.jti)?;
// initiates the auth session
inner.inner.simple_query("discard all").await?;
inner
.inner
.query(
"select auth.jwt_session_init($1)",
&[&token as &(dyn ToSql + Sync)],
)
.await?;
let pid = inner.inner.get_process_id();
info!(pid, jti = inner.jti, "user session state init");
Ok(())
}
}
/// implements relatively efficient in-place json object key upserting
///
/// only supports top-level keys
@@ -548,15 +521,24 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
jwt
}
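For context on what signing and resigning produce: with the p256 and base64ct imports above, a compact-JWS ES256 signer has the following shape. This is a hedged sketch, not the exact sign_jwt body shown in the hunk header, and the header claims are assumptions:

use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
use p256::ecdsa::{Signature, SigningKey};
use signature::Signer;

fn sign_es256(sk: &SigningKey, payload: &[u8]) -> String {
    let header = Base64UrlUnpadded::encode_string(br#"{"alg":"ES256"}"#);
    let body = Base64UrlUnpadded::encode_string(payload);
    let message = format!("{header}.{body}");
    // ES256: ECDSA over P-256 with SHA-256, fixed-size (r || s) encoding.
    let sig: Signature = sk.sign(message.as_bytes());
    let sig_b64 = Base64UrlUnpadded::encode_string(sig.to_bytes().as_slice());
    format!("{message}.{sig_b64}")
}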
impl<C: ClientInnerExt> LocalClient<C> {
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux;
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
})
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!(
"local_pool: throwing away connection '{conn_info}' because connection is not idle"
);
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}
impl<C: ClientInnerExt> LocalClient<C> {
fn do_drop(&mut self) -> Option<impl FnOnce()> {
let conn_info = self.conn_info.clone();
let client = self
@@ -583,23 +565,6 @@ impl<C: ClientInnerExt> Drop for LocalClient<C> {
}
}
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!(
"local_pool: throwing away connection '{conn_info}' because connection is not idle"
);
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}
#[cfg(test)]
mod tests {
use p256::ecdsa::SigningKey;

View File

@@ -5,7 +5,6 @@
mod backend;
pub mod cancel_set;
mod conn_pool;
mod conn_pool_lib;
mod http_conn_pool;
mod http_util;
mod json;
@@ -13,15 +12,12 @@ mod local_conn_pool;
mod sql_over_http;
mod websocket;
use std::net::{IpAddr, SocketAddr};
use std::pin::{pin, Pin};
use std::sync::Arc;
use anyhow::Context;
use async_trait::async_trait;
use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool_lib::GlobalConnPoolOptions;
pub use conn_pool::GlobalConnPoolOptions;
use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
@@ -33,13 +29,9 @@ use hyper_util::server::conn::auto::Builder;
use rand::rngs::StdRng;
use rand::SeedableRng;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tracing::{info, warn, Instrument};
use utils::http::error::ApiError;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
@@ -51,6 +43,14 @@ use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use std::net::{IpAddr, SocketAddr};
use std::pin::{pin, Pin};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{info, warn, Instrument};
use utils::http::error::ApiError;
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub async fn task_main(
@@ -66,7 +66,7 @@ pub async fn task_main(
}
let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config);
let conn_pool = conn_pool_lib::GlobalConnPool::new(&config.http_config);
let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
{
let conn_pool = Arc::clone(&conn_pool);
tokio::spawn(async move {

View File

@@ -2,45 +2,77 @@ use std::pin::pin;
use std::sync::Arc;
use bytes::Bytes;
use futures::future::{select, try_join, Either};
use futures::{StreamExt, TryFutureExt};
use futures::future::select;
use futures::future::try_join;
use futures::future::Either;
use futures::StreamExt;
use futures::TryFutureExt;
use http::header::AUTHORIZATION;
use http::Method;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::body::{Body, Incoming};
use hyper::http::{HeaderName, HeaderValue};
use hyper::{header, HeaderMap, Request, Response, StatusCode};
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper::body::Body;
use hyper::body::Incoming;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{HeaderMap, Request};
use pq_proto::StartupMessageParamsBuilder;
use serde::Serialize;
use serde_json::Value;
use tokio::time;
use tokio_postgres::error::{DbError, ErrorPosition, SqlState};
use tokio_postgres::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction};
use tokio_postgres::error::DbError;
use tokio_postgres::error::ErrorPosition;
use tokio_postgres::error::SqlState;
use tokio_postgres::GenericClient;
use tokio_postgres::IsolationLevel;
use tokio_postgres::NoTls;
use tokio_postgres::ReadyForQueryStatus;
use tokio_postgres::Transaction;
use tokio_util::sync::CancellationToken;
use tracing::{error, info};
use tracing::error;
use tracing::info;
use typed_json::json;
use url::Url;
use urlencoding;
use utils::http::error::ApiError;
use super::backend::{LocalProxyConnError, PoolingBackend};
use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool_lib::{self, ConnInfo};
use super::http_util::json_response;
use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError};
use super::local_conn_pool;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::{endpoint_sni, ComputeUserInfoParseError};
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
use crate::auth::backend::ComputeCredentialKeys;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::HttpConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::{HttpDirection, Metrics};
use crate::proxy::{run_until_cancelled, NeonOptions};
use crate::error::ErrorKind;
use crate::error::ReportableError;
use crate::error::UserFacingError;
use crate::metrics::HttpDirection;
use crate::metrics::Metrics;
use crate::proxy::run_until_cancelled;
use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
use crate::usage_metrics::MetricCounter;
use crate::usage_metrics::MetricCounterRecorder;
use crate::DbName;
use crate::RoleName;
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::{DbName, RoleName};
use super::backend::LocalProxyConnError;
use super::backend::PoolingBackend;
use super::conn_pool;
use super::conn_pool::AuthData;
use super::conn_pool::ConnInfo;
use super::conn_pool::ConnInfoWithAuth;
use super::http_util::json_response;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::json::JsonConversionError;
use super::local_conn_pool;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -609,8 +641,7 @@ async fn handle_db_inner(
let client = match keys.keys {
ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => {
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
-let (cli_inner, _dsc) = client.client_inner();
-cli_inner.set_jwt_session(&payload).await?;
+client.set_jwt_session(&payload).await?;
Client::Local(client)
}
_ => {
@@ -1024,12 +1055,12 @@ async fn query_to_json<T: GenericClient>(
}
enum Client {
-Remote(conn_pool_lib::Client<tokio_postgres::Client>),
+Remote(conn_pool::Client<tokio_postgres::Client>),
Local(local_conn_pool::LocalClient<tokio_postgres::Client>),
}
enum Discard<'a> {
-Remote(conn_pool_lib::Discard<'a, tokio_postgres::Client>),
+Remote(conn_pool::Discard<'a, tokio_postgres::Client>),
Local(local_conn_pool::Discard<'a, tokio_postgres::Client>),
}
@@ -1044,7 +1075,7 @@ impl Client {
fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
match self {
Client::Remote(client) => {
-let (c, d) = client.inner_mut();
+let (c, d) = client.inner();
(c, Discard::Remote(d))
}
Client::Local(local_client) => {
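Taken together, these hunks revert the `conn_pool_lib` indirection: the `Client` enum goes back to wrapping `conn_pool::Client`, and `set_jwt_session` is called on the wrapper itself rather than reaching through `client_inner()`. A minimal, self-contained sketch of the enum-dispatch shape involved, with stand-in types rather than the proxy's real pooled clients:

// Stand-in types; the real variants wrap pooled tokio_postgres clients.
struct RemoteConn(String);
struct LocalConn(String);

enum Client {
    Remote(RemoteConn),
    Local(LocalConn),
}

impl Client {
    // Mirrors the diff's `inner()`: borrow the underlying connection
    // regardless of which pool it came from. The real method also
    // returns a `Discard` guard so the pool can drop broken connections.
    fn inner(&mut self) -> &mut String {
        match self {
            Client::Remote(RemoteConn(c)) => c,
            Client::Local(LocalConn(c)) => c,
        }
    }
}

fn main() {
    let mut client = Client::Remote(RemoteConn("remote".into()));
    client.inner().push_str("-in-use");
    assert_eq!(client.inner().as_str(), "remote-in-use");
}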

View File

@@ -1,7 +1,13 @@
-use std::pin::Pin;
-use std::sync::Arc;
-use std::task::{ready, Context, Poll};
+use crate::proxy::ErrorSource;
+use crate::{
+    cancellation::CancellationHandlerMain,
+    config::ProxyConfig,
+    context::RequestMonitoring,
+    error::{io_error, ReportableError},
+    metrics::Metrics,
+    proxy::{handle_client, ClientMode},
+    rate_limiter::EndpointRateLimiter,
+};
use anyhow::Context as _;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
@@ -9,17 +15,15 @@ use futures::{Sink, Stream};
use hyper::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
+use std::{
+    pin::Pin,
+    sync::Arc,
+    task::{ready, Context, Poll},
+};
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tracing::warn;
-use crate::cancellation::CancellationHandlerMain;
-use crate::config::ProxyConfig;
-use crate::context::RequestMonitoring;
-use crate::error::{io_error, ReportableError};
-use crate::metrics::Metrics;
-use crate::proxy::{handle_client, ClientMode, ErrorSource};
-use crate::rate_limiter::EndpointRateLimiter;
pin_project! {
/// This is a wrapper around a [`WebSocketStream`] that
/// implements [`AsyncRead`] and [`AsyncWrite`].
@@ -180,11 +184,14 @@ mod tests {
use framed_websockets::WebSocketServer;
use futures::{SinkExt, StreamExt};
-use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
-use tokio::task::JoinSet;
-use tokio_tungstenite::tungstenite::protocol::Role;
-use tokio_tungstenite::tungstenite::Message;
-use tokio_tungstenite::WebSocketStream;
+use tokio::{
+    io::{duplex, AsyncReadExt, AsyncWriteExt},
+    task::JoinSet,
+};
+use tokio_tungstenite::{
+    tungstenite::{protocol::Role, Message},
+    WebSocketStream,
+};
use super::WebSocketRw;
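Per its doc comment, `WebSocketRw` adapts a framed websocket connection into an `AsyncRead`/`AsyncWrite` byte stream. A rough sketch of just the read half of that adapter, assuming frames already sit in a queue rather than arriving from a live `framed_websockets` server:

use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, ReadBuf};

// Sketch only: a queue of already-received, non-empty frames stands in
// for polling the websocket; the real type also wakes on new frames.
struct FrameReader {
    frames: VecDeque<Vec<u8>>,
    current: Vec<u8>,
}

impl AsyncRead for FrameReader {
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut(); // FrameReader is Unpin
        if this.current.is_empty() {
            match this.frames.pop_front() {
                Some(frame) => this.current = frame,
                // Queue drained: filling zero bytes signals EOF.
                None => return Poll::Ready(Ok(())),
            }
        }
        // Copy as much of the current frame as the caller's buffer allows.
        let n = this.current.len().min(buf.remaining());
        buf.put_slice(&this.current[..n]);
        this.current.drain(..n);
        Poll::Ready(Ok(()))
    }
}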

View File

@@ -1,20 +1,19 @@
-use std::pin::Pin;
-use std::sync::Arc;
-use std::{io, task};
+use crate::config::TlsServerEndPoint;
+use crate::error::{ErrorKind, ReportableError, UserFacingError};
+use crate::metrics::Metrics;
use bytes::BytesMut;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::{io, task};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use tracing::debug;
-use crate::config::TlsServerEndPoint;
-use crate::error::{ErrorKind, ReportableError, UserFacingError};
-use crate::metrics::Metrics;
/// Stream wrapper which implements libpq's protocol.
///
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
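The wrapper above deliberately trades `AsyncRead` for a message-level API: callers receive whole libpq messages and cannot desynchronize the stream by reading partial ones. A hedged sketch of that framing idea, using a made-up u32 length prefix instead of the real `pq_proto` wire format:

use std::io;

use bytes::{Buf, BytesMut};
use tokio::io::{AsyncRead, AsyncReadExt};

struct Framed<S> {
    stream: S,
    buf: BytesMut,
}

impl<S: AsyncRead + Unpin> Framed<S> {
    // Yield one length-prefixed message, buffering until it is complete;
    // pq_proto's Framed plays the analogous role for real FeMessages.
    async fn read_message(&mut self) -> io::Result<Option<Vec<u8>>> {
        loop {
            if self.buf.len() >= 4 {
                let len = u32::from_be_bytes(self.buf[..4].try_into().unwrap()) as usize;
                if self.buf.len() >= 4 + len {
                    self.buf.advance(4); // drop the length prefix
                    return Ok(Some(self.buf.split_to(len).to_vec()));
                }
            }
            // Need more bytes; a read of 0 means the peer closed the
            // connection (cleanly, if it happened between messages).
            if self.stream.read_buf(&mut self.buf).await? == 0 {
                return Ok(None);
            }
        }
    }
}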

View File

@@ -1,33 +1,36 @@
//! Periodically collect proxy consumption metrics
//! and push them to a HTTP endpoint.
-use std::convert::Infallible;
-use std::pin::pin;
-use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
-use std::sync::Arc;
-use std::time::Duration;
+use crate::{
+    config::{MetricBackupCollectionConfig, MetricCollectionConfig},
+    context::parquet::{FAILED_UPLOAD_MAX_RETRIES, FAILED_UPLOAD_WARN_THRESHOLD},
+    http,
+    intern::{BranchIdInt, EndpointIdInt},
+};
use anyhow::Context;
use async_compression::tokio::write::GzipEncoder;
use bytes::Bytes;
use chrono::{DateTime, Datelike, Timelike, Utc};
use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE};
-use dashmap::mapref::entry::Entry;
-use dashmap::DashMap;
+use dashmap::{mapref::entry::Entry, DashMap};
use futures::future::select;
use once_cell::sync::Lazy;
use remote_storage::{GenericRemoteStorage, RemotePath, TimeoutOrCancel};
use serde::{Deserialize, Serialize};
+use std::{
+    convert::Infallible,
+    pin::pin,
+    sync::{
+        atomic::{AtomicU64, AtomicUsize, Ordering},
+        Arc,
+    },
+    time::Duration,
+};
use tokio::io::AsyncWriteExt;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, instrument, trace, warn};
use utils::backoff;
use uuid::{NoContext, Timestamp};
-use crate::config::{MetricBackupCollectionConfig, MetricCollectionConfig};
-use crate::context::parquet::{FAILED_UPLOAD_MAX_RETRIES, FAILED_UPLOAD_WARN_THRESHOLD};
-use crate::http;
-use crate::intern::{BranchIdInt, EndpointIdInt};
const PROXY_IO_BYTES_PER_CLIENT: &str = "proxy_io_bytes_per_client";
const HTTP_REPORTING_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
@@ -482,23 +485,19 @@ async fn upload_events_chunk(
#[cfg(test)]
mod tests {
-use std::sync::{Arc, Mutex};
+use super::*;
+use crate::{http, BranchId, EndpointId};
use anyhow::Error;
use chrono::Utc;
use consumption_metrics::{Event, EventChunk};
use http_body_util::BodyExt;
-use hyper::body::Incoming;
-use hyper::server::conn::http1;
-use hyper::service::service_fn;
-use hyper::{Request, Response};
+use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Response};
use hyper_util::rt::TokioIo;
+use std::sync::{Arc, Mutex};
use tokio::net::TcpListener;
use url::Url;
-use super::*;
-use crate::{http, BranchId, EndpointId};
#[tokio::test]
async fn metrics() {
type Report = EventChunk<'static, Event<Ids, String>>;
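For context, this module's doc comment (first hunk above) promises a periodic collect-and-push loop. A minimal sketch of that shape under stated assumptions; `collect_metrics_iteration` is a hypothetical helper here, not the crate's confirmed API:

use std::time::Duration;

use tokio_util::sync::CancellationToken;

// Every tick, drain the per-endpoint counters and POST one chunk of
// events; idempotency keys keep retried chunks from double-counting.
async fn collect_metrics(cancel: CancellationToken, interval: Duration) {
    let mut ticker = tokio::time::interval(interval);
    loop {
        tokio::select! {
            _ = cancel.cancelled() => break,
            _ = ticker.tick() => {
                if let Err(e) = collect_metrics_iteration().await {
                    tracing::error!("metrics collection failed: {:#}", e);
                }
            }
        }
    }
}

async fn collect_metrics_iteration() -> anyhow::Result<()> {
    // Hypothetical body: gather counters, serialize an EventChunk,
    // gzip the payload, and POST it to the collection endpoint.
    Ok(())
}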

Some files were not shown because too many files have changed in this diff.