Compare commits

..

16 Commits

Author SHA1 Message Date
Conrad Ludgate
631139ceeb turns out the boxing isn't necessary, we just needed to massage the stack usage properly 2025-05-30 08:47:44 +01:00
Conrad Ludgate
fd43058bd7 optimise passthrough calling convention to further reduce memory 2025-05-29 18:35:24 +01:00
Conrad Ludgate
cf07c5b5f9 dont box handle_client anymore and move spawning passthrough into handle_client so we don't need to move a heavy object in return position anymore 2025-05-29 18:20:29 +01:00
Conrad Ludgate
11bb84c38d save 1000 bytes by removing instrument 2025-05-29 17:56:25 +01:00
Conrad Ludgate
219c72c24c optimise proxy_pass memory size a little, also boxing requestcontext since it is large 2025-05-29 17:52:26 +01:00
Conrad Ludgate
0633cd6385 small changes to connect compute mechanism/backend handling 2025-05-29 16:21:55 +01:00
Conrad Ludgate
0cdb0c5704 reuse the same tracker token for websockets and http 2025-05-29 16:04:14 +01:00
Conrad Ludgate
eefac5d78b box the connect to compute task 2025-05-29 15:58:28 +01:00
Conrad Ludgate
7d1c908b1b box authenticate task 2025-05-29 15:55:17 +01:00
Conrad Ludgate
cfa2813446 remove unnecessary aux field from passthrough 2025-05-29 15:51:57 +01:00
Conrad Ludgate
034bdb1552 move more work inside handshake 2025-05-29 15:50:10 +01:00
Conrad Ludgate
8b1ffa1718 simplify cplane authentication 2025-05-29 15:46:40 +01:00
Conrad Ludgate
2d3ea77953 box the handshake task 2025-05-29 15:39:33 +01:00
Conrad Ludgate
3124729f53 spawn passthrough as a separate task to reduce influence from the handshake task 2025-05-29 15:21:54 +01:00
Conrad Ludgate
6463eb38be manually handle task tracker tokens 2025-05-29 15:19:03 +01:00
Conrad Ludgate
ae506fd791 proxy: remove unused ip return value 2025-05-29 15:04:40 +01:00
163 changed files with 4232 additions and 7849 deletions
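The proxy commits above all chase the same goal: shrinking per-connection memory. The recurring pattern is to heap-allocate the large, short-lived futures (handshake, authenticate, connect-to-compute) with Box::pin so the enclosing future stays small, and to spawn the long-lived passthrough phase as its own task instead of returning it from handle_client as a heavy object. A minimal, hypothetical sketch of that pattern is below; the names handshake, passthrough and handle_client are illustrative only and are not the proxy's actual API.

use std::time::Duration;

// Stand-in for a large handshake state machine; boxing it keeps callers small.
async fn handshake(conn_id: u32) -> Result<u32, &'static str> {
    tokio::time::sleep(Duration::from_millis(10)).await;
    Ok(conn_id)
}

// Stand-in for the long-lived byte-copying passthrough loop.
async fn passthrough(session: u32) {
    tokio::time::sleep(Duration::from_millis(10)).await;
    println!("session {session} finished");
}

async fn handle_client(conn_id: u32) -> Result<(), &'static str> {
    // Box::pin moves the big handshake future to the heap, so the
    // handle_client future (and the task polling it) stays small.
    let session = Box::pin(handshake(conn_id)).await?;

    // Spawning passthrough as a separate task lets handle_client finish
    // early and avoids moving a heavy object out in return position.
    tokio::spawn(passthrough(session));
    Ok(())
}

#[tokio::main]
async fn main() {
    for conn_id in 0..3 {
        tokio::spawn(handle_client(conn_id));
    }
    // Give the demo tasks a moment to complete before the runtime exits.
    tokio::time::sleep(Duration::from_millis(100)).await;
}

The trade-off is one small heap allocation per boxed step in exchange for a much smaller spawned-task footprint; the first commit in the list notes that once stack usage is arranged carefully, some of that boxing becomes unnecessary again.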

Cargo.lock (generated), 134 changed lines
View File

@@ -1086,25 +1086,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cbindgen"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eadd868a2ce9ca38de7eeafdcec9c7065ef89b42b32f0839278d55f35c54d1ff"
dependencies = [
"clap",
"heck 0.4.1",
"indexmap 2.9.0",
"log",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn 2.0.100",
"tempfile",
"toml",
]
[[package]]
name = "cc"
version = "1.2.16"
@@ -1231,7 +1212,7 @@ version = "4.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -1289,14 +1270,6 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "communicator"
version = "0.1.0"
dependencies = [
"cbindgen",
"neon-shmem",
]
[[package]]
name = "compute_api"
version = "0.1.0"
@@ -1963,7 +1936,7 @@ checksum = "0892a17df262a24294c382f0d5997571006e7a4348b4327557c4ff1cd4a8bccc"
dependencies = [
"darling",
"either",
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -2527,18 +2500,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "getrandom"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasi 0.14.2+wasi-0.2.4",
]
[[package]]
name = "gettid"
version = "0.1.3"
@@ -2751,12 +2712,6 @@ dependencies = [
"http 1.1.0",
]
[[package]]
name = "heck"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "heck"
version = "0.5.0"
@@ -3693,7 +3648,7 @@ version = "0.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -3755,7 +3710,7 @@ dependencies = [
"procfs",
"prometheus",
"rand 0.8.5",
"rand_distr 0.4.3",
"rand_distr",
"twox-hash",
]
@@ -3844,8 +3799,6 @@ name = "neon-shmem"
version = "0.1.0"
dependencies = [
"nix 0.30.1",
"rand 0.9.1",
"rand_distr 0.5.1",
"tempfile",
"thiserror 1.0.69",
"workspace_hack",
@@ -4283,8 +4236,6 @@ name = "pagebench"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"bytes",
"camino",
"clap",
"futures",
@@ -4293,15 +4244,12 @@ dependencies = [
"humantime-serde",
"pageserver_api",
"pageserver_client",
"pageserver_page_api",
"rand 0.8.5",
"reqwest",
"serde",
"serde_json",
"tokio",
"tokio-stream",
"tokio-util",
"tonic 0.13.1",
"tracing",
"utils",
"workspace_hack",
@@ -4357,7 +4305,6 @@ dependencies = [
"hashlink",
"hex",
"hex-literal",
"http 1.1.0",
"http-utils",
"humantime",
"humantime-serde",
@@ -4420,7 +4367,6 @@ dependencies = [
"toml_edit",
"tonic 0.13.1",
"tonic-reflection",
"tower 0.5.2",
"tracing",
"tracing-utils",
"twox-hash",
@@ -4517,6 +4463,7 @@ dependencies = [
"pageserver_api",
"postgres_ffi",
"prost 0.13.5",
"smallvec",
"thiserror 1.0.69",
"tonic 0.13.1",
"tonic-build",
@@ -5139,7 +5086,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes",
"heck 0.5.0",
"heck",
"itertools 0.12.1",
"log",
"multimap",
@@ -5160,7 +5107,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
dependencies = [
"bytes",
"heck 0.5.0",
"heck",
"itertools 0.12.1",
"log",
"multimap",
@@ -5285,7 +5232,7 @@ dependencies = [
"postgres_backend",
"pq_proto",
"rand 0.8.5",
"rand_distr 0.4.3",
"rand_distr",
"rcgen",
"redis",
"regex",
@@ -5389,12 +5336,6 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5"
[[package]]
name = "rand"
version = "0.7.3"
@@ -5419,16 +5360,6 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
@@ -5449,16 +5380,6 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.5.1"
@@ -5477,15 +5398,6 @@ dependencies = [
"getrandom 0.2.11",
]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom 0.3.3",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
@@ -5496,16 +5408,6 @@ dependencies = [
"rand 0.8.5",
]
[[package]]
name = "rand_distr"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand 0.9.1",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
@@ -6992,7 +6894,7 @@ version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"rustversion",
@@ -8291,15 +8193,6 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasi"
version = "0.14.2+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wasite"
version = "0.1.0"
@@ -8657,15 +8550,6 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [
"bitflags 2.8.0",
]
[[package]]
name = "workspace_hack"
version = "0.1.0"

View File

@@ -44,7 +44,6 @@ members = [
"libs/proxy/postgres-types2",
"libs/proxy/tokio-postgres2",
"endpoint_storage",
"pgxn/neon/communicator",
]
[workspace.package]
@@ -252,7 +251,6 @@ desim = { version = "0.1", path = "./libs/desim" }
endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" }
http-utils = { version = "0.1", path = "./libs/http-utils/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
neon-shmem = { version = "0.1", path = "./libs/neon-shmem/" }
pageserver = { path = "./pageserver" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
pageserver_client = { path = "./pageserver/client" }
@@ -280,7 +278,6 @@ walproposer = { version = "0.1", path = "./libs/walproposer/" }
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
## Build dependencies
cbindgen = "0.28.0"
criterion = "0.5.1"
rcgen = "0.13"
rstest = "0.18"

View File

@@ -18,12 +18,10 @@ ifeq ($(BUILD_TYPE),release)
PG_LDFLAGS = $(LDFLAGS)
# Unfortunately, `--profile=...` is a nightly feature
CARGO_BUILD_FLAGS += --release
NEON_CARGO_ARTIFACT_TARGET_DIR = $(ROOT_PROJECT_DIR)/target/release
else ifeq ($(BUILD_TYPE),debug)
PG_CONFIGURE_OPTS = --enable-debug --with-openssl --enable-cassert --enable-depend
PG_CFLAGS += -O0 -g3 $(CFLAGS)
PG_LDFLAGS = $(LDFLAGS)
NEON_CARGO_ARTIFACT_TARGET_DIR = $(ROOT_PROJECT_DIR)/target/debug
else
$(error Bad build type '$(BUILD_TYPE)', see Makefile for options)
endif
@@ -182,16 +180,11 @@ postgres-check-%: postgres-%
.PHONY: neon-pg-ext-%
neon-pg-ext-%: postgres-%
+@echo "Compiling communicator $*"
$(CARGO_CMD_PREFIX) cargo build -p communicator $(CARGO_BUILD_FLAGS)
+@echo "Compiling neon $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
LIBCOMMUNICATOR_PATH=$(NEON_CARGO_ARTIFACT_TARGET_DIR) \
-C $(POSTGRES_INSTALL_DIR)/build/neon-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile install
+@echo "Compiling neon_walredo $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \

View File

@@ -310,13 +310,13 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
. "$HOME/.cargo/env" && \
cargo --version && rustup --version && \
rustup component add llvm-tools rustfmt clippy && \
cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \
cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \
cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \
cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \
cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \
cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \
cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \
cargo install rustfilt --version ${RUSTFILT_VERSION} && \
cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} && \
cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \
cargo install cargo-hack --version ${CARGO_HACK_VERSION} && \
cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} && \
cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \
cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} \
--features postgres-bundled --no-default-features && \
rm -rf /home/nonroot/.cargo/registry && \
rm -rf /home/nonroot/.cargo/git

View File

@@ -297,7 +297,6 @@ RUN ./autogen.sh && \
./configure --with-sfcgal=/usr/local/bin/sfcgal-config && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \
make staged-install && \
cd extensions/postgis && \
make clean && \
make -j $(getconf _NPROCESSORS_ONLN) install && \
@@ -603,7 +602,7 @@ RUN case "${PG_VERSION:?}" in \
;; \
esac && \
wget https://github.com/knizhnik/online_advisor/archive/refs/tags/1.0.tar.gz -O online_advisor.tar.gz && \
echo "37dcadf8f7cc8d6cc1f8831276ee245b44f1b0274f09e511e47a67738ba9ed0f online_advisor.tar.gz" | sha256sum --check && \
echo "059b7d9e5a90013a58bdd22e9505b88406ce05790675eb2d8434e5b215652d54 online_advisor.tar.gz" | sha256sum --check && \
mkdir online_advisor-src && cd online_advisor-src && tar xzf ../online_advisor.tar.gz --strip-components=1 -C .
FROM pg-build AS online_advisor-build
@@ -1181,14 +1180,14 @@ RUN cd exts/rag && \
RUN cd exts/rag_bge_small_en_v15 && \
sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \
REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/bge_small_en_v15.onnx \
REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/bge_small_en_v15.onnx \
cargo pgrx install --release --features remote_onnx && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_bge_small_en_v15.control
RUN cd exts/rag_jina_reranker_v1_tiny_en && \
sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \
REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/jina_reranker_v1_tiny_en.onnx \
REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/jina_reranker_v1_tiny_en.onnx \
cargo pgrx install --release --features remote_onnx && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_jina_reranker_v1_tiny_en.control
@@ -1843,25 +1842,10 @@ RUN make PG_VERSION="${PG_VERSION:?}" -C compute
FROM pg-build AS extension-tests
ARG PG_VERSION
# This is required for the PostGIS test
RUN apt-get update && case $DEBIAN_VERSION in \
bullseye) \
apt-get install -y libproj19 libgdal28 time; \
;; \
bookworm) \
apt-get install -y libgdal32 libproj25 time; \
;; \
*) \
echo "Unknown Debian version ${DEBIAN_VERSION}" && exit 1 \
;; \
esac
COPY docker-compose/ext-src/ /ext-src/
COPY --from=pg-build /postgres /postgres
COPY --from=postgis-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=postgis-build /ext-src/postgis-src /ext-src/postgis-src
COPY --from=postgis-build /sfcgal/* /usr
#COPY --from=postgis-src /ext-src/ /ext-src/
COPY --from=plv8-src /ext-src/ /ext-src/
COPY --from=h3-pg-src /ext-src/h3-pg-src /ext-src/h3-pg-src
COPY --from=postgresql-unit-src /ext-src/ /ext-src/
@@ -1902,7 +1886,6 @@ COPY compute/patches/pg_repack.patch /ext-src
RUN cd /ext-src/pg_repack-src && patch -p1 </ext-src/pg_repack.patch && rm -f /ext-src/pg_repack.patch
COPY --chmod=755 docker-compose/run-tests.sh /run-tests.sh
RUN echo /usr/local/pgsql/lib > /etc/ld.so.conf.d/00-neon.conf && /sbin/ldconfig
RUN apt-get update && apt-get install -y libtap-parser-sourcehandler-pgtap-perl jq \
&& apt clean && rm -rf /ext-src/*.tar.gz /ext-src/*.patch /var/lib/apt/lists/*
ENV PATH=/usr/local/pgsql/bin:$PATH

View File

@@ -1,121 +0,0 @@
pg_settings:
# Common settings for primaries and replicas of all versions.
common:
# Check for client disconnection every 1 minute. By default, Postgres will detect the
# loss of the connection only at the next interaction with the socket, when it waits
# for, receives or sends data, so it will likely waste resources till the end of the
# query execution. There should be no drawbacks in setting this for everyone, so enable
# it by default. If anyone will complain, we can allow editing it.
# https://www.postgresql.org/docs/16/runtime-config-connection.html#GUC-CLIENT-CONNECTION-CHECK-INTERVAL
client_connection_check_interval: "60000" # 1 minute
# ---- IO ----
effective_io_concurrency: "20"
maintenance_io_concurrency: "100"
fsync: "off"
hot_standby: "off"
# We allow users to change this if needed, but by default we
# just don't want to see long-lasting idle transactions, as they
# prevent activity monitor from suspending projects.
idle_in_transaction_session_timeout: "300000" # 5 minutes
listen_addresses: "*"
# --- LOGGING ---- helps investigations
log_connections: "on"
log_disconnections: "on"
# 1GB, unit is KB
log_temp_files: "1048576"
# Disable dumping customer data to logs, both to increase data privacy
# and to reduce the amount the logs.
log_error_verbosity: "terse"
log_min_error_statement: "panic"
max_connections: "100"
# --- WAL ---
# - flush lag is the max amount of WAL that has been generated but not yet stored
# to disk in the page server. A smaller value means less delay after a pageserver
# restart, but if you set it too small you might again need to slow down writes if the
# pageserver cannot flush incoming WAL to disk fast enough. This must be larger
# than the pageserver's checkpoint interval, currently 1 GB! Otherwise you get a
# a deadlock where the compute node refuses to generate more WAL before the
# old WAL has been uploaded to S3, but the pageserver is waiting for more WAL
# to be generated before it is uploaded to S3.
max_replication_flush_lag: "10GB"
max_replication_slots: "10"
# Backpressure configuration:
# - write lag is the max amount of WAL that has been generated by Postgres but not yet
# processed by the page server. Making this smaller reduces the worst case latency
# of a GetPage request, if you request a page that was recently modified. On the other
# hand, if this is too small, the compute node might need to wait on a write if there is a
# hiccup in the network or page server so that the page server has temporarily fallen
# behind.
#
# Previously it was set to 500 MB, but it caused compute being unresponsive under load
# https://github.com/neondatabase/neon/issues/2028
max_replication_write_lag: "500MB"
max_wal_senders: "10"
# A Postgres checkpoint is cheap in storage, as doesn't involve any significant amount
# of real I/O. Only the SLRU buffers and some other small files are flushed to disk.
# However, as long as we have full_page_writes=on, page updates after a checkpoint
# include full-page images which bloats the WAL. So may want to bump max_wal_size to
# reduce the WAL bloating, but at the same it will increase pg_wal directory size on
# compute and can lead to out of disk error on k8s nodes.
max_wal_size: "1024"
wal_keep_size: "0"
wal_level: "replica"
# Reduce amount of WAL generated by default.
wal_log_hints: "off"
# - without wal_sender_timeout set we don't get feedback messages,
# required for backpressure.
wal_sender_timeout: "10000"
# We have some experimental extensions, which we don't want users to install unconsciously.
# To install them, users would need to set the `neon.allow_unstable_extensions` setting.
# There are two of them currently:
# - `pgrag` - https://github.com/neondatabase-labs/pgrag - extension is actually called just `rag`,
# and two dependencies:
# - `rag_bge_small_en_v15`
# - `rag_jina_reranker_v1_tiny_en`
# - `pg_mooncake` - https://github.com/Mooncake-Labs/pg_mooncake/
neon.unstable_extensions: "rag,rag_bge_small_en_v15,rag_jina_reranker_v1_tiny_en,pg_mooncake,anon"
neon.protocol_version: "3"
password_encryption: "scram-sha-256"
# This is important to prevent Postgres from trying to perform
# a local WAL redo after backend crash. It should exit and let
# the systemd or k8s to do a fresh startup with compute_ctl.
restart_after_crash: "off"
# By default 3. We have the following persistent connections in the VM:
# * compute_activity_monitor (from compute_ctl)
# * postgres-exporter (metrics collector; it has 2 connections)
# * sql_exporter (metrics collector; we have 2 instances [1 for us & users; 1 for autoscaling])
# * vm-monitor (to query & change file cache size)
# i.e. total of 6. Let's reserve 7, so there's still at least one left over.
superuser_reserved_connections: "7"
synchronous_standby_names: "walproposer"
replica:
hot_standby: "on"
per_version:
17:
common:
# PostgreSQL 17 has a new IO system called "read stream", which can combine IOs up to some
# size. It still has some issues with readahead, though, so we default to disabled/
# "no combining of IOs" to make sure we get the maximum prefetch depth.
# See also: https://github.com/neondatabase/neon/pull/9860
io_combine_limit: "1"
replica:
# prefetching of blocks referenced in WAL doesn't make sense for us
# Neon hot standby ignores pages that are not in the shared_buffers
recovery_prefetch: "off"
16:
common:
replica:
# prefetching of blocks referenced in WAL doesn't make sense for us
# Neon hot standby ignores pages that are not in the shared_buffers
recovery_prefetch: "off"
15:
common:
replica:
# prefetching of blocks referenced in WAL doesn't make sense for us
# Neon hot standby ignores pages that are not in the shared_buffers
recovery_prefetch: "off"
14:
common:
replica:

View File

@@ -40,7 +40,7 @@ use std::sync::mpsc;
use std::thread;
use std::time::Duration;
use anyhow::{Context, Result, bail};
use anyhow::{Context, Result};
use clap::Parser;
use compute_api::responses::ComputeConfig;
use compute_tools::compute::{
@@ -57,15 +57,31 @@ use tracing::{error, info};
use url::Url;
use utils::failpoint_support;
#[derive(Debug, Parser)]
// Compatibility hack: if the control plane specified any remote-ext-config
// use the default value for extension storage proxy gateway.
// Remove this once the control plane is updated to pass the gateway URL
fn parse_remote_ext_base_url(arg: &str) -> Result<String> {
const FALLBACK_PG_EXT_GATEWAY_BASE_URL: &str =
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local";
Ok(if arg.starts_with("http") {
arg
} else {
FALLBACK_PG_EXT_GATEWAY_BASE_URL
}
.to_owned())
}
#[derive(Parser)]
#[command(rename_all = "kebab-case")]
struct Cli {
#[arg(short = 'b', long, default_value = "postgres", env = "POSTGRES_PATH")]
pub pgbin: String,
/// The base URL for the remote extension storage proxy gateway.
#[arg(short = 'r', long, value_parser = Self::parse_remote_ext_base_url)]
pub remote_ext_base_url: Option<Url>,
/// Should be in the form of `http(s)://<gateway-hostname>[:<port>]`.
#[arg(short = 'r', long, value_parser = parse_remote_ext_base_url, alias = "remote-ext-config")]
pub remote_ext_base_url: Option<String>,
/// The port to bind the external listening HTTP server to. Clients running
/// outside the compute will talk to the compute through this port. Keep
@@ -126,25 +142,6 @@ struct Cli {
pub installed_extensions_collection_interval: u64,
}
impl Cli {
/// Parse a URL from an argument. By default, this isn't necessary, but we
/// want to do some sanity checking.
fn parse_remote_ext_base_url(value: &str) -> Result<Url> {
// Remove extra trailing slashes, and add one. We use Url::join() later
// when downloading remote extensions. If the base URL is something like
// http://example.com/pg-ext-s3-gateway, and join() is called with
// something like "xyz", the resulting URL is http://example.com/xyz.
let value = value.trim_end_matches('/').to_owned() + "/";
let url = Url::parse(&value)?;
if url.query_pairs().count() != 0 {
bail!("parameters detected in remote extensions base URL")
}
Ok(url)
}
}
fn main() -> Result<()> {
let cli = Cli::parse();
@@ -271,8 +268,7 @@ fn handle_exit_signal(sig: i32) {
#[cfg(test)]
mod test {
use clap::{CommandFactory, Parser};
use url::Url;
use clap::CommandFactory;
use super::Cli;
@@ -282,41 +278,16 @@ mod test {
}
#[test]
fn verify_remote_ext_base_url() {
let cli = Cli::parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--remote-ext-base-url",
"https://example.com/subpath",
]);
assert_eq!(
cli.remote_ext_base_url.unwrap(),
Url::parse("https://example.com/subpath/").unwrap()
);
fn parse_pg_ext_gateway_base_url() {
let arg = "http://pg-ext-s3-gateway2";
let result = super::parse_remote_ext_base_url(arg).unwrap();
assert_eq!(result, arg);
let cli = Cli::parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--remote-ext-base-url",
"https://example.com//",
]);
let arg = "pg-ext-s3-gateway";
let result = super::parse_remote_ext_base_url(arg).unwrap();
assert_eq!(
cli.remote_ext_base_url.unwrap(),
Url::parse("https://example.com").unwrap()
result,
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local"
);
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--remote-ext-base-url",
"https://example.com?hello=world",
])
.expect_err("URL parameters are not allowed");
}
}

View File

@@ -339,8 +339,6 @@ async fn run_dump_restore(
destination_connstring: String,
) -> Result<(), anyhow::Error> {
let dumpdir = workdir.join("dumpdir");
let num_jobs = num_cpus::get().to_string();
info!("using {num_jobs} jobs for dump/restore");
let common_args = [
// schema mapping (prob suffices to specify them on one side)
@@ -356,7 +354,7 @@ async fn run_dump_restore(
"directory".to_string(),
// concurrency
"--jobs".to_string(),
num_jobs,
num_cpus::get().to_string(),
// progress updates
"--verbose".to_string(),
];

View File

@@ -3,7 +3,7 @@ use chrono::{DateTime, Utc};
use compute_api::privilege::Privilege;
use compute_api::responses::{
ComputeConfig, ComputeCtlConfig, ComputeMetrics, ComputeStatus, LfcOffloadState,
LfcPrewarmState, TlsConfig,
LfcPrewarmState,
};
use compute_api::spec::{
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PgIdent,
@@ -31,7 +31,6 @@ use std::time::{Duration, Instant};
use std::{env, fs};
use tokio::spawn;
use tracing::{Instrument, debug, error, info, instrument, warn};
use url::Url;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use utils::measured_stream::MeasuredReader;
@@ -97,7 +96,7 @@ pub struct ComputeNodeParams {
pub internal_http_port: u16,
/// the address of extension storage proxy gateway
pub remote_ext_base_url: Option<Url>,
pub remote_ext_base_url: Option<String>,
/// Interval for installed extensions collection
pub installed_extensions_collection_interval: u64,
@@ -396,7 +395,7 @@ impl ComputeNode {
// because QEMU will already have its memory allocated from the host, and
// the necessary binaries will already be cached.
if cli_spec.is_none() {
this.prewarm_postgres_vm_memory()?;
this.prewarm_postgres()?;
}
// Set the up metric with Empty status before starting the HTTP server.
@@ -603,8 +602,6 @@ impl ComputeNode {
});
}
let tls_config = self.tls_config(&pspec.spec);
// If there are any remote extensions in shared_preload_libraries, start downloading them
if pspec.spec.remote_extensions.is_some() {
let (this, spec) = (self.clone(), pspec.spec.clone());
@@ -661,7 +658,7 @@ impl ComputeNode {
info!("tuning pgbouncer");
let pgbouncer_settings = pgbouncer_settings.clone();
let tls_config = tls_config.clone();
let tls_config = self.compute_ctl_config.tls.clone();
// Spawn a background task to do the tuning,
// so that we don't block the main thread that starts Postgres.
@@ -680,10 +677,7 @@ impl ComputeNode {
// Spawn a background task to do the configuration,
// so that we don't block the main thread that starts Postgres.
let mut local_proxy = local_proxy.clone();
local_proxy.tls = tls_config.clone();
let local_proxy = local_proxy.clone();
let _handle = tokio::spawn(async move {
if let Err(err) = local_proxy::configure(&local_proxy) {
error!("error while configuring local_proxy: {err:?}");
@@ -784,7 +778,7 @@ impl ComputeNode {
// Spawn the extension stats background task
self.spawn_extension_stats_task();
if pspec.spec.autoprewarm {
if pspec.spec.prewarm_lfc_on_startup {
self.prewarm_lfc();
}
Ok(())
@@ -1210,15 +1204,13 @@ impl ComputeNode {
let spec = &pspec.spec;
let pgdata_path = Path::new(&self.params.pgdata);
let tls_config = self.tls_config(&pspec.spec);
// Remove/create an empty pgdata directory and put configuration there.
self.create_pgdata()?;
config::write_postgres_conf(
pgdata_path,
&pspec.spec,
self.params.internal_http_port,
tls_config,
&self.compute_ctl_config.tls,
)?;
// Syncing safekeepers is only safe with primary nodes: if a primary
@@ -1314,8 +1306,8 @@ impl ComputeNode {
}
/// Start and stop a postgres process to warm up the VM for startup.
pub fn prewarm_postgres_vm_memory(&self) -> Result<()> {
info!("prewarming VM memory");
pub fn prewarm_postgres(&self) -> Result<()> {
info!("prewarming");
// Create pgdata
let pgdata = &format!("{}.warmup", self.params.pgdata);
@@ -1357,7 +1349,7 @@ impl ComputeNode {
kill(pm_pid, Signal::SIGQUIT)?;
info!("sent SIGQUIT signal");
pg.wait()?;
info!("done prewarming vm memory");
info!("done prewarming");
// clean up
let _ok = fs::remove_dir_all(pgdata);
@@ -1543,22 +1535,14 @@ impl ComputeNode {
.clone(),
);
let mut tls_config = None::<TlsConfig>;
if spec.features.contains(&ComputeFeature::TlsExperimental) {
tls_config = self.compute_ctl_config.tls.clone();
}
let max_concurrent_connections = self.max_service_connections(compute_state, &spec);
// Merge-apply spec & changes to PostgreSQL state.
self.apply_spec_sql(spec.clone(), conf.clone(), max_concurrent_connections)?;
if let Some(local_proxy) = &spec.clone().local_proxy_config {
let mut local_proxy = local_proxy.clone();
local_proxy.tls = tls_config.clone();
info!("configuring local_proxy");
local_proxy::configure(&local_proxy).context("apply_config local_proxy")?;
local_proxy::configure(local_proxy).context("apply_config local_proxy")?;
}
// Run migrations separately to not hold up cold starts
@@ -1610,13 +1594,11 @@ impl ComputeNode {
pub fn reconfigure(&self) -> Result<()> {
let spec = self.state.lock().unwrap().pspec.clone().unwrap().spec;
let tls_config = self.tls_config(&spec);
if let Some(ref pgbouncer_settings) = spec.pgbouncer_settings {
info!("tuning pgbouncer");
let pgbouncer_settings = pgbouncer_settings.clone();
let tls_config = tls_config.clone();
let tls_config = self.compute_ctl_config.tls.clone();
// Spawn a background task to do the tuning,
// so that we don't block the main thread that starts Postgres.
@@ -1634,7 +1616,7 @@ impl ComputeNode {
// Spawn a background task to do the configuration,
// so that we don't block the main thread that starts Postgres.
let mut local_proxy = local_proxy.clone();
local_proxy.tls = tls_config.clone();
local_proxy.tls = self.compute_ctl_config.tls.clone();
tokio::spawn(async move {
if let Err(err) = local_proxy::configure(&local_proxy) {
error!("error while configuring local_proxy: {err:?}");
@@ -1652,7 +1634,7 @@ impl ComputeNode {
pgdata_path,
&spec,
self.params.internal_http_port,
tls_config,
&self.compute_ctl_config.tls,
)?;
if !spec.skip_pg_catalog_updates {
@@ -1772,14 +1754,6 @@ impl ComputeNode {
}
}
pub fn tls_config(&self, spec: &ComputeSpec) -> &Option<TlsConfig> {
if spec.features.contains(&ComputeFeature::TlsExperimental) {
&self.compute_ctl_config.tls
} else {
&None::<TlsConfig>
}
}
/// Update the `last_active` in the shared state, but ensure that it's a more recent one.
pub fn update_last_active(&self, last_active: Option<DateTime<Utc>>) {
let mut state = self.state.lock().unwrap();

View File

@@ -83,7 +83,6 @@ use reqwest::StatusCode;
use tar::Archive;
use tracing::info;
use tracing::log::warn;
use url::Url;
use zstd::stream::read::Decoder;
use crate::metrics::{REMOTE_EXT_REQUESTS_TOTAL, UNKNOWN_HTTP_STATUS};
@@ -159,7 +158,7 @@ fn parse_pg_version(human_version: &str) -> PostgresMajorVersion {
pub async fn download_extension(
ext_name: &str,
ext_path: &RemotePath,
remote_ext_base_url: &Url,
remote_ext_base_url: &str,
pgbin: &str,
) -> Result<u64> {
info!("Download extension {:?} from {:?}", ext_name, ext_path);
@@ -271,14 +270,10 @@ pub fn create_control_files(remote_extensions: &RemoteExtSpec, pgbin: &str) {
}
// Do request to extension storage proxy, e.g.,
// curl http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/latest/v15/extensions/anon.tar.zst
// curl http://pg-ext-s3-gateway/latest/v15/extensions/anon.tar.zst
// using HTTP GET and return the response body as bytes.
async fn download_extension_tar(remote_ext_base_url: &Url, ext_path: &str) -> Result<Bytes> {
let uri = remote_ext_base_url.join(ext_path).with_context(|| {
format!(
"failed to create the remote extension URI for {ext_path} using {remote_ext_base_url}"
)
})?;
async fn download_extension_tar(remote_ext_base_url: &str, ext_path: &str) -> Result<Bytes> {
let uri = format!("{}/{}", remote_ext_base_url, ext_path);
let filename = Path::new(ext_path)
.file_name()
.unwrap_or_else(|| std::ffi::OsStr::new("unknown"))
@@ -288,7 +283,7 @@ async fn download_extension_tar(remote_ext_base_url: &Url, ext_path: &str) -> Re
info!("Downloading extension file '{}' from uri {}", filename, uri);
match do_extension_server_request(uri).await {
match do_extension_server_request(&uri).await {
Ok(resp) => {
info!("Successfully downloaded remote extension data {}", ext_path);
REMOTE_EXT_REQUESTS_TOTAL
@@ -307,7 +302,7 @@ async fn download_extension_tar(remote_ext_base_url: &Url, ext_path: &str) -> Re
// Do a single remote extensions server request.
// Return result or (error message + stringified status code) in case of any failures.
async fn do_extension_server_request(uri: Url) -> Result<Bytes, (String, String)> {
async fn do_extension_server_request(uri: &str) -> Result<Bytes, (String, String)> {
let resp = reqwest::get(uri).await.map_err(|e| {
(
format!(

View File

@@ -48,9 +48,11 @@ impl JsonResponse {
/// Create an error response related to the compute being in an invalid state
pub(self) fn invalid_status(status: ComputeStatus) -> Response {
Self::error(
Self::create_response(
StatusCode::PRECONDITION_FAILED,
format!("invalid compute status: {status}"),
&GenericAPIError {
error: format!("invalid compute status: {status}"),
},
)
}
}

View File

@@ -22,7 +22,7 @@ pub(in crate::http) async fn configure(
State(compute): State<Arc<ComputeNode>>,
request: Json<ConfigurationRequest>,
) -> Response {
let pspec = match ParsedSpec::try_from(request.0.spec) {
let pspec = match ParsedSpec::try_from(request.spec.clone()) {
Ok(p) => p,
Err(e) => return JsonResponse::error(StatusCode::BAD_REQUEST, e),
};

View File

@@ -13,12 +13,6 @@ use crate::metrics::{PG_CURR_DOWNTIME_MS, PG_TOTAL_DOWNTIME_MS};
const MONITOR_CHECK_INTERVAL: Duration = Duration::from_millis(500);
/// Struct to store runtime state of the compute monitor thread.
/// In theory, this could be a part of `Compute`, but i)
/// this state is expected to be accessed only by single thread,
/// so we don't need to care about locking; ii) `Compute` is
/// already quite big. Thus, it seems to be a good idea to keep
/// all the activity/health monitoring parts here.
struct ComputeMonitor {
compute: Arc<ComputeNode>,
@@ -76,36 +70,12 @@ impl ComputeMonitor {
)
}
/// Check if compute is in some terminal or soon-to-be-terminal
/// state, then return `true`, signalling the caller that it
/// should exit gracefully. Otherwise, return `false`.
fn check_interrupts(&mut self) -> bool {
let compute_status = self.compute.get_status();
if matches!(
compute_status,
ComputeStatus::Terminated | ComputeStatus::TerminationPending | ComputeStatus::Failed
) {
info!(
"compute is in {} status, stopping compute monitor",
compute_status
);
return true;
}
false
}
/// Spin in a loop and figure out the last activity time in the Postgres.
/// Then update it in the shared state. This function currently never
/// errors out explicitly, but there is a graceful termination path.
/// Every time we receive an error trying to check Postgres, we use
/// [`ComputeMonitor::check_interrupts()`] because it could be that
/// compute is being terminated already, then we can exit gracefully
/// to not produce errors' noise in the log.
/// Then update it in the shared state. This function never errors out.
/// NB: the only expected panic is at `Mutex` unwrap(), all other errors
/// should be handled gracefully.
#[instrument(skip_all)]
pub fn run(&mut self) -> anyhow::Result<()> {
pub fn run(&mut self) {
// Suppose that `connstr` doesn't change
let connstr = self.compute.params.connstr.clone();
let conf = self
@@ -123,10 +93,6 @@ impl ComputeMonitor {
info!("starting compute monitor for {}", connstr);
loop {
if self.check_interrupts() {
break;
}
match &mut client {
Ok(cli) => {
if cli.is_closed() {
@@ -134,10 +100,6 @@ impl ComputeMonitor {
downtime_info = self.downtime_info(),
"connection to Postgres is closed, trying to reconnect"
);
if self.check_interrupts() {
break;
}
self.report_down();
// Connection is closed, reconnect and try again.
@@ -149,19 +111,15 @@ impl ComputeMonitor {
self.compute.update_last_active(self.last_active);
}
Err(e) => {
error!(
downtime_info = self.downtime_info(),
"could not check Postgres: {}", e
);
if self.check_interrupts() {
break;
}
// Although we have many places where we can return errors in `check()`,
// normally it shouldn't happen. I.e., we will likely return error if
// connection got broken, query timed out, Postgres returned invalid data, etc.
// In all such cases it's suspicious, so let's report this as downtime.
self.report_down();
error!(
downtime_info = self.downtime_info(),
"could not check Postgres: {}", e
);
// Reconnect to Postgres just in case. During tests, I noticed
// that queries in `check()` can fail with `connection closed`,
@@ -178,10 +136,6 @@ impl ComputeMonitor {
downtime_info = self.downtime_info(),
"could not connect to Postgres: {}, retrying", e
);
if self.check_interrupts() {
break;
}
self.report_down();
// Establish a new connection and try again.
@@ -193,9 +147,6 @@ impl ComputeMonitor {
self.last_checked = Utc::now();
thread::sleep(MONITOR_CHECK_INTERVAL);
}
// Graceful termination path
Ok(())
}
#[instrument(skip_all)]
@@ -478,10 +429,7 @@ pub fn launch_monitor(compute: &Arc<ComputeNode>) -> thread::JoinHandle<()> {
.spawn(move || {
let span = span!(Level::INFO, "compute_monitor");
let _enter = span.enter();
match monitor.run() {
Ok(_) => info!("compute monitor thread terminated gracefully"),
Err(err) => error!("compute monitor thread terminated abnormally {:?}", err),
}
monitor.run();
})
.expect("cannot launch compute monitor thread")
}

View File

@@ -30,7 +30,7 @@ mod pg_helpers_tests {
r#"fsync = off
wal_level = logical
hot_standby = on
autoprewarm = off
prewarm_lfc_on_startup = off
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
wal_log_hints = on
log_connections = on

View File

@@ -747,7 +747,7 @@ impl Endpoint {
logs_export_host: None::<String>,
endpoint_storage_addr: Some(endpoint_storage_addr),
endpoint_storage_token: Some(endpoint_storage_token),
autoprewarm: false,
prewarm_lfc_on_startup: false,
};
// this strange code is needed to support respec() in tests

View File

@@ -513,6 +513,11 @@ impl PageServerNode {
.map(|x| x.parse::<bool>())
.transpose()
.context("Failed to parse 'timeline_offloading' as bool")?,
wal_receiver_protocol_override: settings
.remove("wal_receiver_protocol_override")
.map(serde_json::from_str)
.transpose()
.context("parse `wal_receiver_protocol_override` from json")?,
rel_size_v2_enabled: settings
.remove("rel_size_v2_enabled")
.map(|x| x.parse::<bool>())

View File

@@ -13,6 +13,6 @@ RUN echo 'Acquire::Retries "5";' > /etc/apt/apt.conf.d/80-retries && \
jq \
netcat-openbsd
#This is required for the pg_hintplan test
RUN mkdir -p /ext-src/pg_hint_plan-src /postgres/contrib/file_fdw /ext-src/postgis-src/ && chown postgres /ext-src/pg_hint_plan-src /postgres/contrib/file_fdw /ext-src/postgis-src
RUN mkdir -p /ext-src/pg_hint_plan-src /postgres/contrib/file_fdw && chown postgres /ext-src/pg_hint_plan-src /postgres/contrib/file_fdw
USER postgres

View File

@@ -1,18 +1,18 @@
#!/usr/bin/env bash
#!/bin/bash
set -eux
# Generate a random tenant or timeline ID
#
# Takes a variable name as argument. The result is stored in that variable.
generate_id() {
local -n resvar=${1}
printf -v resvar '%08x%08x%08x%08x' ${SRANDOM} ${SRANDOM} ${SRANDOM} ${SRANDOM}
local -n resvar=$1
printf -v resvar '%08x%08x%08x%08x' $SRANDOM $SRANDOM $SRANDOM $SRANDOM
}
PG_VERSION=${PG_VERSION:-14}
readonly CONFIG_FILE_ORG=/var/db/postgres/configs/config.json
readonly CONFIG_FILE=/tmp/config.json
CONFIG_FILE_ORG=/var/db/postgres/configs/config.json
CONFIG_FILE=/tmp/config.json
# Test that the first library path that the dynamic loader looks in is the path
# that we use for custom compiled software
@@ -20,17 +20,17 @@ first_path="$(ldconfig --verbose 2>/dev/null \
| grep --invert-match ^$'\t' \
| cut --delimiter=: --fields=1 \
| head --lines=1)"
test "${first_path}" = '/usr/local/lib'
test "$first_path" == '/usr/local/lib'
echo "Waiting pageserver become ready."
while ! nc -z pageserver 6400; do
sleep 1
sleep 1;
done
echo "Page server is ready."
cp "${CONFIG_FILE_ORG}" "${CONFIG_FILE}"
cp ${CONFIG_FILE_ORG} ${CONFIG_FILE}
if [[ -n "${TENANT_ID:-}" && -n "${TIMELINE_ID:-}" ]]; then
if [ -n "${TENANT_ID:-}" ] && [ -n "${TIMELINE_ID:-}" ]; then
tenant_id=${TENANT_ID}
timeline_id=${TIMELINE_ID}
else
@@ -41,7 +41,7 @@ else
"http://pageserver:9898/v1/tenant"
)
tenant_id=$(curl "${PARAMS[@]}" | jq -r .[0].id)
if [[ -z "${tenant_id}" || "${tenant_id}" = null ]]; then
if [ -z "${tenant_id}" ] || [ "${tenant_id}" = null ]; then
echo "Create a tenant"
generate_id tenant_id
PARAMS=(
@@ -51,7 +51,7 @@ else
"http://pageserver:9898/v1/tenant/${tenant_id}/location_config"
)
result=$(curl "${PARAMS[@]}")
printf '%s\n' "${result}" | jq .
echo $result | jq .
fi
echo "Check if a timeline present"
@@ -61,7 +61,7 @@ else
"http://pageserver:9898/v1/tenant/${tenant_id}/timeline"
)
timeline_id=$(curl "${PARAMS[@]}" | jq -r .[0].timeline_id)
if [[ -z "${timeline_id}" || "${timeline_id}" = null ]]; then
if [ -z "${timeline_id}" ] || [ "${timeline_id}" = null ]; then
generate_id timeline_id
PARAMS=(
-sbf
@@ -71,7 +71,7 @@ else
"http://pageserver:9898/v1/tenant/${tenant_id}/timeline/"
)
result=$(curl "${PARAMS[@]}")
printf '%s\n' "${result}" | jq .
echo $result | jq .
fi
fi
@@ -82,10 +82,10 @@ else
fi
echo "Adding pgx_ulid"
shared_libraries=$(jq -r '.spec.cluster.settings[] | select(.name=="shared_preload_libraries").value' ${CONFIG_FILE})
sed -i "s|${shared_libraries}|${shared_libraries},${ulid_extension}|" ${CONFIG_FILE}
sed -i "s/${shared_libraries}/${shared_libraries},${ulid_extension}/" ${CONFIG_FILE}
echo "Overwrite tenant id and timeline id in spec file"
sed -i "s|TENANT_ID|${tenant_id}|" ${CONFIG_FILE}
sed -i "s|TIMELINE_ID|${timeline_id}|" ${CONFIG_FILE}
sed -i "s/TENANT_ID/${tenant_id}/" ${CONFIG_FILE}
sed -i "s/TIMELINE_ID/${timeline_id}/" ${CONFIG_FILE}
cat ${CONFIG_FILE}
@@ -93,5 +93,5 @@ echo "Start compute node"
/usr/local/bin/compute_ctl --pgdata /var/db/postgres/compute \
-C "postgresql://cloud_admin@localhost:55433/postgres" \
-b /usr/local/bin/postgres \
--compute-id "compute-${RANDOM}" \
--config "${CONFIG_FILE}"
--compute-id "compute-$RANDOM" \
--config "$CONFIG_FILE"

View File

@@ -186,14 +186,13 @@ services:
neon-test-extensions:
profiles: ["test-extensions"]
image: ${REPOSITORY:-ghcr.io/neondatabase}/neon-test-extensions-v${PG_TEST_VERSION:-${PG_VERSION:-16}}:${TEST_EXTENSIONS_TAG:-${TAG:-latest}}
image: ${REPOSITORY:-ghcr.io/neondatabase}/neon-test-extensions-v${PG_TEST_VERSION:-16}:${TEST_EXTENSIONS_TAG:-${TAG:-latest}}
environment:
- PGUSER=${PGUSER:-cloud_admin}
- PGPASSWORD=${PGPASSWORD:-cloud_admin}
- PGPASSWORD=cloud_admin
entrypoint:
- "/bin/bash"
- "-c"
command:
- sleep 3600
- sleep 1800
depends_on:
- compute

View File

@@ -54,15 +54,6 @@ for pg_version in ${TEST_VERSION_ONLY-14 15 16 17}; do
# It cannot be moved to Dockerfile now because the database directory is created after the start of the container
echo Adding dummy config
docker compose exec compute touch /var/db/postgres/compute/compute_ctl_temp_override.conf
# Prepare for the PostGIS test
docker compose exec compute mkdir -p /tmp/pgis_reg/pgis_reg_tmp
TMPDIR=$(mktemp -d)
docker compose cp neon-test-extensions:/ext-src/postgis-src/raster/test "${TMPDIR}"
docker compose cp neon-test-extensions:/ext-src/postgis-src/regress/00-regress-install "${TMPDIR}"
docker compose exec compute mkdir -p /ext-src/postgis-src/raster /ext-src/postgis-src/regress /ext-src/postgis-src/regress/00-regress-install
docker compose cp "${TMPDIR}/test" compute:/ext-src/postgis-src/raster/test
docker compose cp "${TMPDIR}/00-regress-install" compute:/ext-src/postgis-src/regress
rm -rf "${TMPDIR}"
# The following block copies the files for the pg_hintplan test to the compute node for the extension test in an isolated docker-compose environment
TMPDIR=$(mktemp -d)
docker compose cp neon-test-extensions:/ext-src/pg_hint_plan-src/data "${TMPDIR}/data"
@@ -77,7 +68,7 @@ for pg_version in ${TEST_VERSION_ONLY-14 15 16 17}; do
docker compose exec -T neon-test-extensions bash -c "(cd /postgres && patch -p1)" <"../compute/patches/contrib_pg${pg_version}.patch"
# We are running tests now
rm -f testout.txt testout_contrib.txt
docker compose exec -e USE_PGXS=1 -e SKIP=timescaledb-src,rdkit-src,pg_jsonschema-src,kq_imcx-src,wal2json_2_5-src,rag_jina_reranker_v1_tiny_en-src,rag_bge_small_en_v15-src \
docker compose exec -e USE_PGXS=1 -e SKIP=timescaledb-src,rdkit-src,postgis-src,pg_jsonschema-src,kq_imcx-src,wal2json_2_5-src,rag_jina_reranker_v1_tiny_en-src,rag_bge_small_en_v15-src \
neon-test-extensions /run-tests.sh /ext-src | tee testout.txt && EXT_SUCCESS=1 || EXT_SUCCESS=0
docker compose exec -e SKIP=start-scripts,postgres_fdw,ltree_plpython,jsonb_plpython,jsonb_plperl,hstore_plpython,hstore_plperl,dblink,bool_plperl \
neon-test-extensions /run-tests.sh /postgres/contrib | tee testout_contrib.txt && CONTRIB_SUCCESS=1 || CONTRIB_SUCCESS=0

View File

@@ -1,70 +0,0 @@
# PostGIS Testing in Neon
This directory contains configuration files and patches for running PostGIS tests in the Neon database environment.
## Overview
PostGIS is a spatial database extension for PostgreSQL that adds support for geographic objects. Testing PostGIS compatibility ensures that Neon's modifications to PostgreSQL don't break compatibility with this critical extension.
## PostGIS Versions
- PostgreSQL v17: PostGIS 3.5.0
- PostgreSQL v14/v15/v16: PostGIS 3.3.3
## Test Configuration
The test setup includes:
- `postgis-no-upgrade-test.patch`: Disables upgrade tests by removing the upgrade test section from regress/runtest.mk
- `postgis-regular-v16.patch`: Version-specific patch for PostgreSQL v16
- `postgis-regular-v17.patch`: Version-specific patch for PostgreSQL v17
- `regular-test.sh`: Script to run PostGIS tests as a regular user
- `neon-test.sh`: Script to handle version-specific test configurations
- `raster_outdb_template.sql`: Template for raster tests with explicit file paths
## Excluded Tests
**Important Note:** The test exclusions listed below are specifically for regular-user tests against staging instances. These exclusions are necessary because staging instances run with limited privileges and cannot perform operations requiring superuser access. Docker-compose based tests are not affected by these exclusions.
### Tests Requiring Superuser Permissions
These tests cannot be run as a regular user:
- `estimatedextent`
- `regress/core/legacy`
- `regress/core/typmod`
- `regress/loader/TestSkipANALYZE`
- `regress/loader/TestANALYZE`
### Tests Requiring Filesystem Access
These tests need direct filesystem access that is only possible for superusers:
- `loader/load_outdb`
### Tests with Flaky Results
These tests have assumptions that don't always hold true:
- `regress/core/computed_columns` - Assumes computed columns always outperform alternatives, which is not consistently true
### Tests Requiring Tunable Parameter Modifications
These tests attempt to modify the `postgis.gdal_enabled_drivers` parameter, which is only accessible to superusers:
- `raster/test/regress/rt_wkb`
- `raster/test/regress/rt_addband`
- `raster/test/regress/rt_setbandpath`
- `raster/test/regress/rt_fromgdalraster`
- `raster/test/regress/rt_asgdalraster`
- `raster/test/regress/rt_astiff`
- `raster/test/regress/rt_asjpeg`
- `raster/test/regress/rt_aspng`
- `raster/test/regress/permitted_gdal_drivers`
- Loader tests: `BasicOutDB`, `Tiled10x10`, `Tiled10x10Copy`, `Tiled8x8`, `TiledAuto`, `TiledAutoSkipNoData`, `TiledAutoCopyn`
### Topology Tests (v17 only)
- `populate_topology_layer`
- `renametopogeometrycolumn`
## Other Modifications
- Binary.sql tests are modified to use explicit file paths
- Server-side SQL COPY commands (which require superuser privileges) are converted to client-side `\copy` commands
- Upgrade tests are disabled

View File

@@ -1,6 +0,0 @@
#!/bin/sh
set -ex
cd "$(dirname "$0")"
patch -p1 <"postgis-common-${PG_VERSION}.patch"
trap 'echo Cleaning up; patch -R -p1 <postgis-common-${PG_VERSION}.patch' EXIT
make installcheck-base

View File

@@ -1,37 +0,0 @@
diff --git a/regress/core/tests.mk b/regress/core/tests.mk
index 3abd7bc..64a9254 100644
--- a/regress/core/tests.mk
+++ b/regress/core/tests.mk
@@ -144,11 +144,6 @@ TESTS_SLOW = \
$(top_srcdir)/regress/core/concave_hull_hard \
$(top_srcdir)/regress/core/knn_recheck
-ifeq ($(shell expr "$(POSTGIS_PGSQL_VERSION)" ">=" 120),1)
- TESTS += \
- $(top_srcdir)/regress/core/computed_columns
-endif
-
ifeq ($(shell expr "$(POSTGIS_GEOS_VERSION)" ">=" 30700),1)
# GEOS-3.7 adds:
# ST_FrechetDistance
diff --git a/regress/runtest.mk b/regress/runtest.mk
index c051f03..010e493 100644
--- a/regress/runtest.mk
+++ b/regress/runtest.mk
@@ -24,16 +24,6 @@ check-regress:
POSTGIS_TOP_BUILD_DIR=$(abs_top_builddir) $(PERL) $(top_srcdir)/regress/run_test.pl $(RUNTESTFLAGS) $(RUNTESTFLAGS_INTERNAL) $(TESTS)
- @if echo "$(RUNTESTFLAGS)" | grep -vq -- --upgrade; then \
- echo "Running upgrade test as RUNTESTFLAGS did not contain that"; \
- POSTGIS_TOP_BUILD_DIR=$(abs_top_builddir) $(PERL) $(top_srcdir)/regress/run_test.pl \
- --upgrade \
- $(RUNTESTFLAGS) \
- $(RUNTESTFLAGS_INTERNAL) \
- $(TESTS); \
- else \
- echo "Skipping upgrade test as RUNTESTFLAGS already requested upgrades"; \
- fi
check-long:
$(PERL) $(top_srcdir)/regress/run_test.pl $(RUNTESTFLAGS) $(TESTS) $(TESTS_SLOW)

View File

@@ -1,35 +0,0 @@
diff --git a/regress/core/tests.mk b/regress/core/tests.mk
index 9e05244..90987df 100644
--- a/regress/core/tests.mk
+++ b/regress/core/tests.mk
@@ -143,8 +143,7 @@ TESTS += \
$(top_srcdir)/regress/core/oriented_envelope \
$(top_srcdir)/regress/core/point_coordinates \
$(top_srcdir)/regress/core/out_geojson \
- $(top_srcdir)/regress/core/wrapx \
- $(top_srcdir)/regress/core/computed_columns
+ $(top_srcdir)/regress/core/wrapx
# Slow slow tests
TESTS_SLOW = \
diff --git a/regress/runtest.mk b/regress/runtest.mk
index 4b95b7e..449d5a2 100644
--- a/regress/runtest.mk
+++ b/regress/runtest.mk
@@ -24,16 +24,6 @@ check-regress:
@POSTGIS_TOP_BUILD_DIR=$(abs_top_builddir) $(PERL) $(top_srcdir)/regress/run_test.pl $(RUNTESTFLAGS) $(RUNTESTFLAGS_INTERNAL) $(TESTS)
- @if echo "$(RUNTESTFLAGS)" | grep -vq -- --upgrade; then \
- echo "Running upgrade test as RUNTESTFLAGS did not contain that"; \
- POSTGIS_TOP_BUILD_DIR=$(abs_top_builddir) $(PERL) $(top_srcdir)/regress/run_test.pl \
- --upgrade \
- $(RUNTESTFLAGS) \
- $(RUNTESTFLAGS_INTERNAL) \
- $(TESTS); \
- else \
- echo "Skipping upgrade test as RUNTESTFLAGS already requested upgrades"; \
- fi
check-long:
$(PERL) $(top_srcdir)/regress/run_test.pl $(RUNTESTFLAGS) $(TESTS) $(TESTS_SLOW)

View File

@@ -1,186 +0,0 @@
diff --git a/raster/test/regress/tests.mk b/raster/test/regress/tests.mk
index 00918e1..7e2b6cd 100644
--- a/raster/test/regress/tests.mk
+++ b/raster/test/regress/tests.mk
@@ -17,9 +17,7 @@ override RUNTESTFLAGS_INTERNAL := \
$(RUNTESTFLAGS_INTERNAL) \
--after-upgrade-script $(top_srcdir)/raster/test/regress/hooks/hook-after-upgrade-raster.sql
-RASTER_TEST_FIRST = \
- $(top_srcdir)/raster/test/regress/check_gdal \
- $(top_srcdir)/raster/test/regress/loader/load_outdb
+RASTER_TEST_FIRST =
RASTER_TEST_LAST = \
$(top_srcdir)/raster/test/regress/clean
@@ -33,9 +31,7 @@ RASTER_TEST_IO = \
RASTER_TEST_BASIC_FUNC = \
$(top_srcdir)/raster/test/regress/rt_bytea \
- $(top_srcdir)/raster/test/regress/rt_wkb \
$(top_srcdir)/raster/test/regress/box3d \
- $(top_srcdir)/raster/test/regress/rt_addband \
$(top_srcdir)/raster/test/regress/rt_band \
$(top_srcdir)/raster/test/regress/rt_tile
@@ -73,16 +69,10 @@ RASTER_TEST_BANDPROPS = \
$(top_srcdir)/raster/test/regress/rt_neighborhood \
$(top_srcdir)/raster/test/regress/rt_nearestvalue \
$(top_srcdir)/raster/test/regress/rt_pixelofvalue \
- $(top_srcdir)/raster/test/regress/rt_polygon \
- $(top_srcdir)/raster/test/regress/rt_setbandpath
+ $(top_srcdir)/raster/test/regress/rt_polygon
RASTER_TEST_UTILITY = \
$(top_srcdir)/raster/test/regress/rt_utility \
- $(top_srcdir)/raster/test/regress/rt_fromgdalraster \
- $(top_srcdir)/raster/test/regress/rt_asgdalraster \
- $(top_srcdir)/raster/test/regress/rt_astiff \
- $(top_srcdir)/raster/test/regress/rt_asjpeg \
- $(top_srcdir)/raster/test/regress/rt_aspng \
$(top_srcdir)/raster/test/regress/rt_reclass \
$(top_srcdir)/raster/test/regress/rt_gdalwarp \
$(top_srcdir)/raster/test/regress/rt_gdalcontour \
@@ -120,21 +110,13 @@ RASTER_TEST_SREL = \
RASTER_TEST_BUGS = \
$(top_srcdir)/raster/test/regress/bug_test_car5 \
- $(top_srcdir)/raster/test/regress/permitted_gdal_drivers \
$(top_srcdir)/raster/test/regress/tickets
RASTER_TEST_LOADER = \
$(top_srcdir)/raster/test/regress/loader/Basic \
$(top_srcdir)/raster/test/regress/loader/Projected \
$(top_srcdir)/raster/test/regress/loader/BasicCopy \
- $(top_srcdir)/raster/test/regress/loader/BasicFilename \
- $(top_srcdir)/raster/test/regress/loader/BasicOutDB \
- $(top_srcdir)/raster/test/regress/loader/Tiled10x10 \
- $(top_srcdir)/raster/test/regress/loader/Tiled10x10Copy \
- $(top_srcdir)/raster/test/regress/loader/Tiled8x8 \
- $(top_srcdir)/raster/test/regress/loader/TiledAuto \
- $(top_srcdir)/raster/test/regress/loader/TiledAutoSkipNoData \
- $(top_srcdir)/raster/test/regress/loader/TiledAutoCopyn
+ $(top_srcdir)/raster/test/regress/loader/BasicFilename
RASTER_TESTS := $(RASTER_TEST_FIRST) \
$(RASTER_TEST_METADATA) $(RASTER_TEST_IO) $(RASTER_TEST_BASIC_FUNC) \
diff --git a/regress/core/binary.sql b/regress/core/binary.sql
index 7a36b65..ad78fc7 100644
--- a/regress/core/binary.sql
+++ b/regress/core/binary.sql
@@ -1,4 +1,5 @@
SET client_min_messages TO warning;
+
CREATE SCHEMA tm;
CREATE TABLE tm.geoms (id serial, g geometry);
@@ -31,24 +32,39 @@ SELECT st_force4d(g) FROM tm.geoms WHERE id < 15 ORDER BY id;
INSERT INTO tm.geoms(g)
SELECT st_setsrid(g,4326) FROM tm.geoms ORDER BY id;
-COPY tm.geoms TO :tmpfile WITH BINARY;
+-- define temp file path
+\set tmpfile '/tmp/postgis_binary_test.dat'
+
+-- export
+\set command '\\copy tm.geoms TO ':tmpfile' WITH (FORMAT BINARY)'
+:command
+
+-- import
CREATE TABLE tm.geoms_in AS SELECT * FROM tm.geoms LIMIT 0;
-COPY tm.geoms_in FROM :tmpfile WITH BINARY;
-SELECT 'geometry', count(*) FROM tm.geoms_in i, tm.geoms o WHERE i.id = o.id
- AND ST_OrderingEquals(i.g, o.g);
+\set command '\\copy tm.geoms_in FROM ':tmpfile' WITH (FORMAT BINARY)'
+:command
+
+SELECT 'geometry', count(*) FROM tm.geoms_in i, tm.geoms o
+WHERE i.id = o.id AND ST_OrderingEquals(i.g, o.g);
CREATE TABLE tm.geogs AS SELECT id,g::geography FROM tm.geoms
WHERE geometrytype(g) NOT LIKE '%CURVE%'
AND geometrytype(g) NOT LIKE '%CIRCULAR%'
AND geometrytype(g) NOT LIKE '%SURFACE%'
AND geometrytype(g) NOT LIKE 'TRIANGLE%'
- AND geometrytype(g) NOT LIKE 'TIN%'
-;
+ AND geometrytype(g) NOT LIKE 'TIN%';
-COPY tm.geogs TO :tmpfile WITH BINARY;
+-- export
+\set command '\\copy tm.geogs TO ':tmpfile' WITH (FORMAT BINARY)'
+:command
+
+-- import
CREATE TABLE tm.geogs_in AS SELECT * FROM tm.geogs LIMIT 0;
-COPY tm.geogs_in FROM :tmpfile WITH BINARY;
-SELECT 'geometry', count(*) FROM tm.geogs_in i, tm.geogs o WHERE i.id = o.id
- AND ST_OrderingEquals(i.g::geometry, o.g::geometry);
+\set command '\\copy tm.geogs_in FROM ':tmpfile' WITH (FORMAT BINARY)'
+:command
+
+SELECT 'geometry', count(*) FROM tm.geogs_in i, tm.geogs o
+WHERE i.id = o.id AND ST_OrderingEquals(i.g::geometry, o.g::geometry);
DROP SCHEMA tm CASCADE;
+
diff --git a/regress/core/tests.mk b/regress/core/tests.mk
index 64a9254..94903c3 100644
--- a/regress/core/tests.mk
+++ b/regress/core/tests.mk
@@ -23,7 +23,6 @@ current_dir := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
RUNTESTFLAGS_INTERNAL += \
--before-upgrade-script $(top_srcdir)/regress/hooks/hook-before-upgrade.sql \
--after-upgrade-script $(top_srcdir)/regress/hooks/hook-after-upgrade.sql \
- --after-create-script $(top_srcdir)/regress/hooks/hook-after-create.sql \
--before-uninstall-script $(top_srcdir)/regress/hooks/hook-before-uninstall.sql
TESTS += \
@@ -40,7 +39,6 @@ TESTS += \
$(top_srcdir)/regress/core/dumppoints \
$(top_srcdir)/regress/core/dumpsegments \
$(top_srcdir)/regress/core/empty \
- $(top_srcdir)/regress/core/estimatedextent \
$(top_srcdir)/regress/core/forcecurve \
$(top_srcdir)/regress/core/flatgeobuf \
$(top_srcdir)/regress/core/geography \
@@ -55,7 +53,6 @@ TESTS += \
$(top_srcdir)/regress/core/out_marc21 \
$(top_srcdir)/regress/core/in_encodedpolyline \
$(top_srcdir)/regress/core/iscollection \
- $(top_srcdir)/regress/core/legacy \
$(top_srcdir)/regress/core/letters \
$(top_srcdir)/regress/core/long_xact \
$(top_srcdir)/regress/core/lwgeom_regress \
@@ -112,7 +109,6 @@ TESTS += \
$(top_srcdir)/regress/core/temporal_knn \
$(top_srcdir)/regress/core/tickets \
$(top_srcdir)/regress/core/twkb \
- $(top_srcdir)/regress/core/typmod \
$(top_srcdir)/regress/core/wkb \
$(top_srcdir)/regress/core/wkt \
$(top_srcdir)/regress/core/wmsservers \
diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk
index 1fc77ac..c3cb9de 100644
--- a/regress/loader/tests.mk
+++ b/regress/loader/tests.mk
@@ -38,7 +38,5 @@ TESTS += \
$(top_srcdir)/regress/loader/Latin1 \
$(top_srcdir)/regress/loader/Latin1-implicit \
$(top_srcdir)/regress/loader/mfile \
- $(top_srcdir)/regress/loader/TestSkipANALYZE \
- $(top_srcdir)/regress/loader/TestANALYZE \
$(top_srcdir)/regress/loader/CharNoWidth
diff --git a/regress/run_test.pl b/regress/run_test.pl
index 0ec5b2d..1c331f4 100755
--- a/regress/run_test.pl
+++ b/regress/run_test.pl
@@ -147,7 +147,6 @@ $ENV{"LANG"} = "C";
# Add locale info to the psql options
# Add pg12 precision suppression
my $PGOPTIONS = $ENV{"PGOPTIONS"};
-$PGOPTIONS .= " -c lc_messages=C";
$PGOPTIONS .= " -c client_min_messages=NOTICE";
$PGOPTIONS .= " -c extra_float_digits=0";
$ENV{"PGOPTIONS"} = $PGOPTIONS;

View File

@@ -1,208 +0,0 @@
diff --git a/raster/test/regress/tests.mk b/raster/test/regress/tests.mk
index 00918e1..7e2b6cd 100644
--- a/raster/test/regress/tests.mk
+++ b/raster/test/regress/tests.mk
@@ -17,9 +17,7 @@ override RUNTESTFLAGS_INTERNAL := \
$(RUNTESTFLAGS_INTERNAL) \
--after-upgrade-script $(top_srcdir)/raster/test/regress/hooks/hook-after-upgrade-raster.sql
-RASTER_TEST_FIRST = \
- $(top_srcdir)/raster/test/regress/check_gdal \
- $(top_srcdir)/raster/test/regress/loader/load_outdb
+RASTER_TEST_FIRST =
RASTER_TEST_LAST = \
$(top_srcdir)/raster/test/regress/clean
@@ -33,9 +31,7 @@ RASTER_TEST_IO = \
RASTER_TEST_BASIC_FUNC = \
$(top_srcdir)/raster/test/regress/rt_bytea \
- $(top_srcdir)/raster/test/regress/rt_wkb \
$(top_srcdir)/raster/test/regress/box3d \
- $(top_srcdir)/raster/test/regress/rt_addband \
$(top_srcdir)/raster/test/regress/rt_band \
$(top_srcdir)/raster/test/regress/rt_tile
@@ -73,16 +69,10 @@ RASTER_TEST_BANDPROPS = \
$(top_srcdir)/raster/test/regress/rt_neighborhood \
$(top_srcdir)/raster/test/regress/rt_nearestvalue \
$(top_srcdir)/raster/test/regress/rt_pixelofvalue \
- $(top_srcdir)/raster/test/regress/rt_polygon \
- $(top_srcdir)/raster/test/regress/rt_setbandpath
+ $(top_srcdir)/raster/test/regress/rt_polygon
RASTER_TEST_UTILITY = \
$(top_srcdir)/raster/test/regress/rt_utility \
- $(top_srcdir)/raster/test/regress/rt_fromgdalraster \
- $(top_srcdir)/raster/test/regress/rt_asgdalraster \
- $(top_srcdir)/raster/test/regress/rt_astiff \
- $(top_srcdir)/raster/test/regress/rt_asjpeg \
- $(top_srcdir)/raster/test/regress/rt_aspng \
$(top_srcdir)/raster/test/regress/rt_reclass \
$(top_srcdir)/raster/test/regress/rt_gdalwarp \
$(top_srcdir)/raster/test/regress/rt_gdalcontour \
@@ -120,21 +110,13 @@ RASTER_TEST_SREL = \
RASTER_TEST_BUGS = \
$(top_srcdir)/raster/test/regress/bug_test_car5 \
- $(top_srcdir)/raster/test/regress/permitted_gdal_drivers \
$(top_srcdir)/raster/test/regress/tickets
RASTER_TEST_LOADER = \
$(top_srcdir)/raster/test/regress/loader/Basic \
$(top_srcdir)/raster/test/regress/loader/Projected \
$(top_srcdir)/raster/test/regress/loader/BasicCopy \
- $(top_srcdir)/raster/test/regress/loader/BasicFilename \
- $(top_srcdir)/raster/test/regress/loader/BasicOutDB \
- $(top_srcdir)/raster/test/regress/loader/Tiled10x10 \
- $(top_srcdir)/raster/test/regress/loader/Tiled10x10Copy \
- $(top_srcdir)/raster/test/regress/loader/Tiled8x8 \
- $(top_srcdir)/raster/test/regress/loader/TiledAuto \
- $(top_srcdir)/raster/test/regress/loader/TiledAutoSkipNoData \
- $(top_srcdir)/raster/test/regress/loader/TiledAutoCopy
+ $(top_srcdir)/raster/test/regress/loader/BasicFilename
RASTER_TESTS := $(RASTER_TEST_FIRST) \
$(RASTER_TEST_METADATA) $(RASTER_TEST_IO) $(RASTER_TEST_BASIC_FUNC) \
diff --git a/regress/core/binary.sql b/regress/core/binary.sql
index 7a36b65..ad78fc7 100644
--- a/regress/core/binary.sql
+++ b/regress/core/binary.sql
@@ -1,4 +1,5 @@
SET client_min_messages TO warning;
+
CREATE SCHEMA tm;
CREATE TABLE tm.geoms (id serial, g geometry);
@@ -31,24 +32,39 @@ SELECT st_force4d(g) FROM tm.geoms WHERE id < 15 ORDER BY id;
INSERT INTO tm.geoms(g)
SELECT st_setsrid(g,4326) FROM tm.geoms ORDER BY id;
-COPY tm.geoms TO :tmpfile WITH BINARY;
+-- define temp file path
+\set tmpfile '/tmp/postgis_binary_test.dat'
+
+-- export
+\set command '\\copy tm.geoms TO ':tmpfile' WITH (FORMAT BINARY)'
+:command
+
+-- import
CREATE TABLE tm.geoms_in AS SELECT * FROM tm.geoms LIMIT 0;
-COPY tm.geoms_in FROM :tmpfile WITH BINARY;
-SELECT 'geometry', count(*) FROM tm.geoms_in i, tm.geoms o WHERE i.id = o.id
- AND ST_OrderingEquals(i.g, o.g);
+\set command '\\copy tm.geoms_in FROM ':tmpfile' WITH (FORMAT BINARY)'
+:command
+
+SELECT 'geometry', count(*) FROM tm.geoms_in i, tm.geoms o
+WHERE i.id = o.id AND ST_OrderingEquals(i.g, o.g);
CREATE TABLE tm.geogs AS SELECT id,g::geography FROM tm.geoms
WHERE geometrytype(g) NOT LIKE '%CURVE%'
AND geometrytype(g) NOT LIKE '%CIRCULAR%'
AND geometrytype(g) NOT LIKE '%SURFACE%'
AND geometrytype(g) NOT LIKE 'TRIANGLE%'
- AND geometrytype(g) NOT LIKE 'TIN%'
-;
+ AND geometrytype(g) NOT LIKE 'TIN%';
-COPY tm.geogs TO :tmpfile WITH BINARY;
+-- export
+\set command '\\copy tm.geogs TO ':tmpfile' WITH (FORMAT BINARY)'
+:command
+
+-- import
CREATE TABLE tm.geogs_in AS SELECT * FROM tm.geogs LIMIT 0;
-COPY tm.geogs_in FROM :tmpfile WITH BINARY;
-SELECT 'geometry', count(*) FROM tm.geogs_in i, tm.geogs o WHERE i.id = o.id
- AND ST_OrderingEquals(i.g::geometry, o.g::geometry);
+\set command '\\copy tm.geogs_in FROM ':tmpfile' WITH (FORMAT BINARY)'
+:command
+
+SELECT 'geometry', count(*) FROM tm.geogs_in i, tm.geogs o
+WHERE i.id = o.id AND ST_OrderingEquals(i.g::geometry, o.g::geometry);
DROP SCHEMA tm CASCADE;
+
diff --git a/regress/core/tests.mk b/regress/core/tests.mk
index 90987df..74fe3f1 100644
--- a/regress/core/tests.mk
+++ b/regress/core/tests.mk
@@ -16,14 +16,13 @@ POSTGIS_PGSQL_VERSION=170
POSTGIS_GEOS_VERSION=31101
HAVE_JSON=yes
HAVE_SPGIST=yes
-INTERRUPTTESTS=yes
+INTERRUPTTESTS=no
current_dir := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
RUNTESTFLAGS_INTERNAL += \
--before-upgrade-script $(top_srcdir)/regress/hooks/hook-before-upgrade.sql \
--after-upgrade-script $(top_srcdir)/regress/hooks/hook-after-upgrade.sql \
- --after-create-script $(top_srcdir)/regress/hooks/hook-after-create.sql \
--before-uninstall-script $(top_srcdir)/regress/hooks/hook-before-uninstall.sql
TESTS += \
@@ -40,7 +39,6 @@ TESTS += \
$(top_srcdir)/regress/core/dumppoints \
$(top_srcdir)/regress/core/dumpsegments \
$(top_srcdir)/regress/core/empty \
- $(top_srcdir)/regress/core/estimatedextent \
$(top_srcdir)/regress/core/forcecurve \
$(top_srcdir)/regress/core/flatgeobuf \
$(top_srcdir)/regress/core/frechet \
@@ -60,7 +58,6 @@ TESTS += \
$(top_srcdir)/regress/core/out_marc21 \
$(top_srcdir)/regress/core/in_encodedpolyline \
$(top_srcdir)/regress/core/iscollection \
- $(top_srcdir)/regress/core/legacy \
$(top_srcdir)/regress/core/letters \
$(top_srcdir)/regress/core/lwgeom_regress \
$(top_srcdir)/regress/core/measures \
@@ -119,7 +116,6 @@ TESTS += \
$(top_srcdir)/regress/core/temporal_knn \
$(top_srcdir)/regress/core/tickets \
$(top_srcdir)/regress/core/twkb \
- $(top_srcdir)/regress/core/typmod \
$(top_srcdir)/regress/core/wkb \
$(top_srcdir)/regress/core/wkt \
$(top_srcdir)/regress/core/wmsservers \
diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk
index ac4f8ad..4bad4fc 100644
--- a/regress/loader/tests.mk
+++ b/regress/loader/tests.mk
@@ -38,7 +38,5 @@ TESTS += \
$(top_srcdir)/regress/loader/Latin1 \
$(top_srcdir)/regress/loader/Latin1-implicit \
$(top_srcdir)/regress/loader/mfile \
- $(top_srcdir)/regress/loader/TestSkipANALYZE \
- $(top_srcdir)/regress/loader/TestANALYZE \
$(top_srcdir)/regress/loader/CharNoWidth \
diff --git a/regress/run_test.pl b/regress/run_test.pl
index cac4b2e..4c7c82b 100755
--- a/regress/run_test.pl
+++ b/regress/run_test.pl
@@ -238,7 +238,6 @@ $ENV{"LANG"} = "C";
# Add locale info to the psql options
# Add pg12 precision suppression
my $PGOPTIONS = $ENV{"PGOPTIONS"};
-$PGOPTIONS .= " -c lc_messages=C";
$PGOPTIONS .= " -c client_min_messages=NOTICE";
$PGOPTIONS .= " -c extra_float_digits=0";
$ENV{"PGOPTIONS"} = $PGOPTIONS;
diff --git a/topology/test/tests.mk b/topology/test/tests.mk
index cbe2633..2c7c18f 100644
--- a/topology/test/tests.mk
+++ b/topology/test/tests.mk
@@ -46,9 +46,7 @@ TESTS += \
$(top_srcdir)/topology/test/regress/legacy_query.sql \
$(top_srcdir)/topology/test/regress/legacy_validate.sql \
$(top_srcdir)/topology/test/regress/polygonize.sql \
- $(top_srcdir)/topology/test/regress/populate_topology_layer.sql \
$(top_srcdir)/topology/test/regress/removeunusedprimitives.sql \
- $(top_srcdir)/topology/test/regress/renametopogeometrycolumn.sql \
$(top_srcdir)/topology/test/regress/renametopology.sql \
$(top_srcdir)/topology/test/regress/share_sequences.sql \
$(top_srcdir)/topology/test/regress/sqlmm.sql \

File diff suppressed because one or more lines are too long

View File

@@ -1,17 +0,0 @@
#!/bin/bash
set -ex
cd "$(dirname "${0}")"
dropdb --if-exists contrib_regression
createdb contrib_regression
psql -d contrib_regression -c "ALTER DATABASE contrib_regression SET TimeZone='UTC'" \
-c "ALTER DATABASE contrib_regression SET DateStyle='ISO, MDY'" \
-c "CREATE EXTENSION postgis SCHEMA public" \
-c "CREATE EXTENSION postgis_topology" \
-c "CREATE EXTENSION postgis_tiger_geocoder CASCADE" \
-c "CREATE EXTENSION postgis_raster SCHEMA public" \
-c "CREATE EXTENSION postgis_sfcgal SCHEMA public"
patch -p1 <"postgis-common-${PG_VERSION}.patch"
patch -p1 <"postgis-regular-${PG_VERSION}.patch"
psql -d contrib_regression -f raster_outdb_template.sql
trap 'patch -R -p1 <postgis-regular-${PG_VERSION}.patch && patch -R -p1 <"postgis-common-${PG_VERSION}.patch"' EXIT
POSTGIS_REGRESS_DB=contrib_regression RUNTESTFLAGS=--nocreate make installcheck-base

View File

@@ -63,9 +63,5 @@ done
for d in ${FAILED}; do
cat "$(find $d -name regression.diffs)"
done
for postgis_diff in /tmp/pgis_reg/*_diff; do
echo "${postgis_diff}:"
cat "${postgis_diff}"
done
echo "${FAILED}"
exit 1

View File

@@ -178,9 +178,9 @@ pub struct ComputeSpec {
/// JWT for authorizing requests to endpoint storage service
pub endpoint_storage_token: Option<String>,
/// Download LFC state from endpoint_storage and pass it to Postgres on startup
/// If true, download LFC state from endpoint_storage and pass it to Postgres on startup
#[serde(default)]
pub autoprewarm: bool,
pub prewarm_lfc_on_startup: bool,
}
/// Feature flag to signal `compute_ctl` to enable certain experimental functionality.
@@ -192,9 +192,6 @@ pub enum ComputeFeature {
/// track short-lived connections as user activity.
ActivityMonitorExperimental,
/// Enable TLS functionality.
TlsExperimental,
/// This is a special feature flag that is used to represent unknown feature flags.
/// Basically all unknown to enum flags are represented as this one. See unit test
/// `parse_unknown_features()` for more details.
@@ -253,44 +250,34 @@ impl RemoteExtSpec {
}
match self.extension_data.get(real_ext_name) {
Some(_ext_data) => Ok((
real_ext_name.to_string(),
Self::build_remote_path(build_tag, pg_major_version, real_ext_name)?,
)),
Some(_ext_data) => {
// We have decided to use the Go naming convention due to Kubernetes.
let arch = match std::env::consts::ARCH {
"x86_64" => "amd64",
"aarch64" => "arm64",
arch => arch,
};
// Construct the path to the extension archive
// BUILD_TAG/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst
//
// Keep it in sync with path generation in
// https://github.com/neondatabase/build-custom-extensions/tree/main
let archive_path_str = format!(
"{build_tag}/{arch}/{pg_major_version}/extensions/{real_ext_name}.tar.zst"
);
Ok((
real_ext_name.to_string(),
RemotePath::from_string(&archive_path_str)?,
))
}
None => Err(anyhow::anyhow!(
"real_ext_name {} is not found",
real_ext_name
)),
}
}
/// Get the architecture-specific portion of the remote extension path. We
/// use the Go naming convention due to Kubernetes.
fn get_arch() -> &'static str {
match std::env::consts::ARCH {
"x86_64" => "amd64",
"aarch64" => "arm64",
arch => arch,
}
}
/// Build a [`RemotePath`] for an extension.
fn build_remote_path(
build_tag: &str,
pg_major_version: &str,
ext_name: &str,
) -> anyhow::Result<RemotePath> {
let arch = Self::get_arch();
// Construct the path to the extension archive
// BUILD_TAG/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst
//
// Keep it in sync with path generation in
// https://github.com/neondatabase/build-custom-extensions/tree/main
RemotePath::from_string(&format!(
"{build_tag}/{arch}/{pg_major_version}/extensions/{ext_name}.tar.zst"
))
}
}
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
@@ -531,37 +518,6 @@ mod tests {
.expect("Library should be found");
}
#[test]
fn remote_extension_path() {
let rspec: RemoteExtSpec = serde_json::from_value(serde_json::json!({
"public_extensions": ["ext"],
"custom_extensions": [],
"library_index": {
"extlib": "ext",
},
"extension_data": {
"ext": {
"control_data": {
"ext.control": ""
},
"archive_path": ""
}
},
}))
.unwrap();
let (_ext_name, ext_path) = rspec
.get_ext("ext", false, "latest", "v17")
.expect("Extension should be found");
// Starting with a forward slash would have consequences for the
// Url::join() that occurs when downloading a remote extension.
assert!(!ext_path.to_string().starts_with("/"));
assert_eq!(
ext_path,
RemoteExtSpec::build_remote_path("latest", "v17", "ext").unwrap()
);
}
#[test]
fn parse_spec_file() {
let file = File::open("tests/cluster_spec.json").unwrap();

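For illustration, a standalone sketch of the path scheme used above: the target architecture is mapped to the Go/Kubernetes names and the archive lives under BUILD_TAG/ARCH/PG_MAJOR_VERSION/extensions/EXTENSION_NAME.tar.zst. The function names below are hypothetical, not the crate's API:

/// Map Rust's architecture names to the Go naming convention.
fn go_arch() -> &'static str {
    match std::env::consts::ARCH {
        "x86_64" => "amd64",
        "aarch64" => "arm64",
        arch => arch,
    }
}

/// Build the remote path for an extension archive, e.g. "latest/amd64/v17/extensions/ext.tar.zst".
/// Note there is no leading slash, so the result can be joined onto a base URL.
fn extension_archive_path(build_tag: &str, pg_major_version: &str, ext_name: &str) -> String {
    format!("{build_tag}/{}/{pg_major_version}/extensions/{ext_name}.tar.zst", go_arch())
}

fn main() {
    // On an x86_64 build host this prints "latest/amd64/v17/extensions/ext.tar.zst".
    println!("{}", extension_archive_path("latest", "v17", "ext"));
}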
View File

@@ -85,7 +85,7 @@
"vartype": "bool"
},
{
"name": "autoprewarm",
"name": "prewarm_lfc_on_startup",
"value": "off",
"vartype": "bool"
},

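For context on the rename above: the same flag appears in the compute spec struct with #[serde(default)], so older spec files that lack the key still parse with the feature off. A minimal sketch of that behaviour, assuming serde and serde_json are available (the surrounding struct is illustrative, not the real ComputeSpec):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct SpecFragment {
    // A missing key deserializes to false, i.e. no LFC prewarm on startup.
    #[serde(default)]
    prewarm_lfc_on_startup: bool,
}

fn main() {
    let old: SpecFragment = serde_json::from_str("{}").unwrap();
    let new: SpecFragment = serde_json::from_str(r#"{"prewarm_lfc_on_startup": true}"#).unwrap();
    assert!(!old.prewarm_lfc_on_startup);
    assert!(new.prewarm_lfc_on_startup);
}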
View File

@@ -107,7 +107,7 @@ impl<const N: usize> MetricType for HyperLogLogState<N> {
}
impl<const N: usize> HyperLogLogState<N> {
pub fn measure(&self, item: &(impl Hash + ?Sized)) {
pub fn measure(&self, item: &impl Hash) {
// changing the hasher will break compatibility with previous measurements.
self.record(BuildHasherDefault::<xxh3::Hash64>::default().hash_one(item));
}

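For context on the ?Sized bound touched above: with &(impl Hash + ?Sized) the caller can pass unsized values such as str or [u8] directly behind a reference, while &impl Hash requires the hashed type itself to be sized, so a &str needs an extra level of reference. A standalone sketch, separate from the metrics crate:

use std::hash::{DefaultHasher, Hash, Hasher};

fn hash_unsized(item: &(impl Hash + ?Sized)) -> u64 {
    let mut h = DefaultHasher::new();
    item.hash(&mut h);
    h.finish()
}

fn hash_sized(item: &impl Hash) -> u64 {
    let mut h = DefaultHasher::new();
    item.hash(&mut h);
    h.finish()
}

fn main() {
    let endpoint = String::from("ep-cool-name-123");
    let _ = hash_unsized(endpoint.as_str()); // T = str, allowed by ?Sized
    let _ = hash_sized(&endpoint.as_str());  // T = &str, needs the extra reference
    let _ = hash_unsized(&[1u8, 2, 3][..]);  // T = [u8]
}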
View File

@@ -27,7 +27,6 @@ pub use prometheus::{
pub mod launch_timestamp;
mod wrappers;
pub use prometheus;
pub use wrappers::{CountedReader, CountedWriter};
mod hll;
pub use hll::{HyperLogLog, HyperLogLogState, HyperLogLogVec};

View File

@@ -6,12 +6,8 @@ license.workspace = true
[dependencies]
thiserror.workspace = true
nix.workspace = true
nix.workspace=true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
[dev-dependencies]
rand = "0.9.1"
rand_distr = "0.5.1"
[target.'cfg(target_os = "macos")'.dependencies]
tempfile = "3.14.0"

View File

@@ -1,304 +0,0 @@
//! Hash table implementation on top of 'shmem'
//!
//! Features required in the long run by the communicator project:
//!
//! [X] Accessible from both Postgres processes and rust threads in the communicator process
//! [X] Low latency
//! [ ] Scalable to lots of concurrent accesses (currently relies on caller for locking)
//! [ ] Resizable
use std::fmt::Debug;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::mem::MaybeUninit;
use crate::shmem::ShmemHandle;
mod core;
pub mod entry;
#[cfg(test)]
mod tests;
use core::CoreHashMap;
use entry::{Entry, OccupiedEntry};
#[derive(Debug)]
pub struct OutOfMemoryError();
pub struct HashMapInit<'a, K, V> {
// Hash table can be allocated in a fixed memory area, or in a resizeable ShmemHandle.
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
}
pub struct HashMapAccess<'a, K, V> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
}
unsafe impl<'a, K: Sync, V: Sync> Sync for HashMapAccess<'a, K, V> {}
unsafe impl<'a, K: Send, V: Send> Send for HashMapAccess<'a, K, V> {}
impl<'a, K, V> HashMapInit<'a, K, V> {
pub fn attach_writer(self) -> HashMapAccess<'a, K, V> {
HashMapAccess {
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
}
}
pub fn attach_reader(self) -> HashMapAccess<'a, K, V> {
// no difference to attach_writer currently
self.attach_writer()
}
}
/// This is stored in the shared memory area
///
/// NOTE: We carve out the parts from a contiguous chunk. Growing and shrinking the hash table
/// relies on the memory layout! The data structures are laid out in the contiguous shared memory
/// area as follows:
///
/// HashMapShared
/// [buckets]
/// [dictionary]
///
/// In between the above parts, there can be padding bytes to align the parts correctly.
struct HashMapShared<'a, K, V> {
inner: CoreHashMap<'a, K, V>,
}
impl<'a, K, V> HashMapInit<'a, K, V>
where
K: Clone + Hash + Eq,
{
pub fn estimate_size(num_buckets: u32) -> usize {
// add some margin to cover alignment etc.
CoreHashMap::<K, V>::estimate_size(num_buckets) + size_of::<HashMapShared<K, V>>() + 1000
}
pub fn init_in_fixed_area(
num_buckets: u32,
area: &'a mut [MaybeUninit<u8>],
) -> HashMapInit<'a, K, V> {
Self::init_common(num_buckets, None, area.as_mut_ptr().cast(), area.len())
}
/// Initialize a new hash map in the given shared memory area
pub fn init_in_shmem(num_buckets: u32, mut shmem: ShmemHandle) -> HashMapInit<'a, K, V> {
let size = Self::estimate_size(num_buckets);
shmem
.set_size(size)
.expect("could not resize shared memory area");
let ptr = unsafe { shmem.data_ptr.as_mut() };
Self::init_common(num_buckets, Some(shmem), ptr, size)
}
fn init_common(
num_buckets: u32,
shmem_handle: Option<ShmemHandle>,
area_ptr: *mut u8,
area_len: usize,
) -> HashMapInit<'a, K, V> {
// carve out the HashMapShared struct from the area.
let mut ptr: *mut u8 = area_ptr;
let end_ptr: *mut u8 = unsafe { area_ptr.add(area_len) };
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
// carve out the buckets
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<core::Bucket<K, V>>())) };
let buckets_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<core::Bucket<K, V>>() * num_buckets as usize) };
// use remaining space for the dictionary
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<u32>())) };
assert!(ptr.addr() < end_ptr.addr());
let dictionary_ptr = ptr;
let dictionary_size = unsafe { end_ptr.byte_offset_from(ptr) / size_of::<u32>() as isize };
assert!(dictionary_size > 0);
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets_ptr.cast(), num_buckets as usize) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
};
let hashmap = CoreHashMap::new(buckets, dictionary);
unsafe {
std::ptr::write(shared_ptr, HashMapShared { inner: hashmap });
}
HashMapInit {
shmem_handle: shmem_handle,
shared_ptr,
}
}
}
impl<'a, K, V> HashMapAccess<'a, K, V>
where
K: Clone + Hash + Eq,
{
pub fn get_hash_value(&self, key: &K) -> u64 {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
hasher.finish()
}
pub fn get_with_hash<'e>(&'e self, key: &K, hash: u64) -> Option<&'e V> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.get_with_hash(key, hash)
}
pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
map.inner.entry_with_hash(key, hash)
}
pub fn remove_with_hash(&mut self, key: &K, hash: u64) {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
match map.inner.entry_with_hash(key.clone(), hash) {
Entry::Occupied(e) => {
e.remove();
}
Entry::Vacant(_) => {}
};
}
pub fn entry_at_bucket(&mut self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
map.inner.entry_at_bucket(pos)
}
pub fn get_num_buckets(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.get_num_buckets()
}
/// Return the key and value stored in bucket with given index. This can be used to
/// iterate through the hash map. (An Iterator might be nicer. The communicator's
/// clock algorithm needs to _slowly_ iterate through all buckets with its clock hand,
/// without holding a lock. If we switch to an Iterator, it must not hold the lock.)
pub fn get_at_bucket(&self, pos: usize) -> Option<&(K, V)> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
if pos >= map.inner.buckets.len() {
return None;
}
let bucket = &map.inner.buckets[pos];
bucket.inner.as_ref()
}
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
let origin = map.inner.buckets.as_ptr();
let idx = (val_ptr as usize - origin as usize) / size_of::<core::Bucket<K, V>>();
assert!(idx < map.inner.buckets.len());
idx
}
// for metrics
pub fn get_num_buckets_in_use(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.buckets_in_use as usize
}
/// Grow
///
/// 1. grow the underlying shared memory area
/// 2. Initialize new buckets. This overwrites the current dictionary
/// 3. Recalculate the dictionary
pub fn grow(&mut self, num_buckets: u32) -> Result<(), crate::shmem::Error> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let inner = &mut map.inner;
let old_num_buckets = inner.buckets.len() as u32;
if num_buckets < old_num_buckets {
panic!("grow called with a smaller number of buckets");
}
if num_buckets == old_num_buckets {
return Ok(());
}
let shmem_handle = self
.shmem_handle
.as_ref()
.expect("grow called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V>::estimate_size(num_buckets);
shmem_handle.set_size(size_bytes)?;
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
// Initialize new buckets. The new buckets are linked to the free list. NB: This overwrites
// the dictionary!
let buckets_ptr = inner.buckets.as_mut_ptr();
unsafe {
for i in old_num_buckets..num_buckets {
let bucket_ptr = buckets_ptr.add(i as usize);
bucket_ptr.write(core::Bucket {
next: if i < num_buckets - 1 {
i as u32 + 1
} else {
inner.free_head
},
inner: None,
});
}
}
// Recalculate the dictionary
let buckets;
let dictionary;
unsafe {
let buckets_end_ptr = buckets_ptr.add(num_buckets as usize);
let dictionary_ptr: *mut u32 = buckets_end_ptr
.byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
.cast();
let dictionary_size: usize =
end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
}
for i in 0..dictionary.len() {
dictionary[i] = core::INVALID_POS;
}
for i in 0..old_num_buckets as usize {
if buckets[i].inner.is_none() {
continue;
}
let mut hasher = DefaultHasher::new();
buckets[i].inner.as_ref().unwrap().0.hash(&mut hasher);
let hash = hasher.finish();
let pos: usize = (hash % dictionary.len() as u64) as usize;
buckets[i].next = dictionary[pos];
dictionary[pos] = i as u32;
}
// Finally, update the CoreHashMap struct
inner.dictionary = dictionary;
inner.buckets = buckets;
inner.free_head = old_num_buckets;
Ok(())
}
// TODO: Shrinking is a multi-step process that requires co-operation from the caller
//
// 1. The caller must first call begin_shrink(). That forbids allocation of higher-numbered
// buckets.
//
// 2. Next, the caller must evict all entries in higher-numbered buckets.
//
// 3. Finally, call finish_shrink(). This recomputes the dictionary and shrinks the underlying
// shmem area
}

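Roughly how this (now removed) map was meant to be driven, as a hedged sketch rather than real test code; the crate and module paths follow the old layout and are assumed. The caller hashes the key once, reuses it for the *_with_hash calls, and a clock-style sweep can walk buckets via get_at_bucket without an iterator:

use neon_shmem::hash::{HashMapInit, entry::Entry};
use neon_shmem::shmem::ShmemHandle;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 10 MB cap for the resizable area, table sized for 1024 buckets.
    let shmem = ShmemHandle::new("hashmap-example", 0, 10 << 20)?;
    let mut map = HashMapInit::<u64, u64>::init_in_shmem(1024, shmem).attach_writer();

    // Insert or update through the Entry API, hashing the key once.
    let key = 42u64;
    let hash = map.get_hash_value(&key);
    match map.entry_with_hash(key, hash) {
        Entry::Occupied(mut e) => { e.insert(1); }
        Entry::Vacant(e) => { e.insert(1).expect("table full"); }
    }

    // Lookup reuses the same hash.
    assert_eq!(map.get_with_hash(&key, hash), Some(&1));

    // Clock-hand style sweep over the buckets, one probe at a time.
    for pos in 0..map.get_num_buckets() {
        if let Some((k, v)) = map.get_at_bucket(pos) {
            println!("bucket {pos}: {k} -> {v}");
        }
    }
    Ok(())
}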
View File

@@ -1,174 +0,0 @@
//! Simple hash table with chaining
//!
//! # Resizing
//!
use std::hash::Hash;
use std::mem::MaybeUninit;
use crate::hash::entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
pub(crate) const INVALID_POS: u32 = u32::MAX;
// Bucket
pub(crate) struct Bucket<K, V> {
pub(crate) next: u32,
pub(crate) inner: Option<(K, V)>,
}
pub(crate) struct CoreHashMap<'a, K, V> {
pub(crate) dictionary: &'a mut [u32],
pub(crate) buckets: &'a mut [Bucket<K, V>],
pub(crate) free_head: u32,
pub(crate) _user_list_head: u32,
// metrics
pub(crate) buckets_in_use: u32,
}
#[derive(Debug)]
pub struct FullError();
impl<'a, K: Hash + Eq, V> CoreHashMap<'a, K, V>
where
K: Clone + Hash + Eq,
{
const FILL_FACTOR: f32 = 0.60;
pub fn estimate_size(num_buckets: u32) -> usize {
let mut size = 0;
// buckets
size += size_of::<Bucket<K, V>>() * num_buckets as usize;
// dictionary
size += (f32::ceil((size_of::<u32>() * num_buckets as usize) as f32 / Self::FILL_FACTOR))
as usize;
size
}
pub fn new(
buckets: &'a mut [MaybeUninit<Bucket<K, V>>],
dictionary: &'a mut [MaybeUninit<u32>],
) -> CoreHashMap<'a, K, V> {
// Initialize the buckets
for i in 0..buckets.len() {
buckets[i].write(Bucket {
next: if i < buckets.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
});
}
// Initialize the dictionary
for i in 0..dictionary.len() {
dictionary[i].write(INVALID_POS);
}
// TODO: use std::slice::assume_init_mut() once it stabilizes
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets.as_mut_ptr().cast(), buckets.len()) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary.as_mut_ptr().cast(), dictionary.len())
};
CoreHashMap {
dictionary,
buckets,
free_head: 0,
buckets_in_use: 0,
_user_list_head: INVALID_POS,
}
}
pub fn get_with_hash(&self, key: &K, hash: u64) -> Option<&V> {
let mut next = self.dictionary[hash as usize % self.dictionary.len()];
loop {
if next == INVALID_POS {
return None;
}
let bucket = &self.buckets[next as usize];
let (bucket_key, bucket_value) = bucket.inner.as_ref().expect("entry is in use");
if bucket_key == key {
return Some(&bucket_value);
}
next = bucket.next;
}
}
// all updates are done through Entry
pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let dict_pos = hash as usize % self.dictionary.len();
let first = self.dictionary[dict_pos];
if first == INVALID_POS {
// no existing entry
return Entry::Vacant(VacantEntry {
map: self,
key,
dict_pos: dict_pos as u32,
});
}
let mut prev_pos = PrevPos::First(dict_pos as u32);
let mut next = first;
loop {
let bucket = &mut self.buckets[next as usize];
let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
if *bucket_key == key {
// found existing entry
return Entry::Occupied(OccupiedEntry {
map: self,
_key: key,
prev_pos,
bucket_pos: next,
});
}
if bucket.next == INVALID_POS {
// No existing entry
return Entry::Vacant(VacantEntry {
map: self,
key,
dict_pos: dict_pos as u32,
});
}
prev_pos = PrevPos::Chained(next);
next = bucket.next;
}
}
pub fn get_num_buckets(&self) -> usize {
self.buckets.len()
}
pub fn entry_at_bucket(&mut self, pos: usize) -> Option<OccupiedEntry<K, V>> {
if pos >= self.buckets.len() {
return None;
}
todo!()
//self.buckets[pos].inner.as_ref()
}
pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result<u32, FullError> {
let pos = self.free_head;
if pos == INVALID_POS {
return Err(FullError());
}
let bucket = &mut self.buckets[pos as usize];
self.free_head = bucket.next;
self.buckets_in_use += 1;
bucket.next = INVALID_POS;
bucket.inner = Some((key, value));
return Ok(pos);
}
}

View File

@@ -1,91 +0,0 @@
//! Like std::collections::hash_map::Entry;
use crate::hash::core::{CoreHashMap, FullError, INVALID_POS};
use std::hash::Hash;
use std::mem;
pub enum Entry<'a, 'b, K, V> {
Occupied(OccupiedEntry<'a, 'b, K, V>),
Vacant(VacantEntry<'a, 'b, K, V>),
}
pub(crate) enum PrevPos {
First(u32),
Chained(u32),
}
pub struct OccupiedEntry<'a, 'b, K, V> {
pub(crate) map: &'b mut CoreHashMap<'a, K, V>,
pub(crate) _key: K, // The key of the occupied entry
pub(crate) prev_pos: PrevPos,
pub(crate) bucket_pos: u32, // The position of the bucket in the CoreHashMap's buckets array
}
impl<'a, 'b, K, V> OccupiedEntry<'a, 'b, K, V> {
pub fn get(&self) -> &V {
&self.map.buckets[self.bucket_pos as usize]
.inner
.as_ref()
.unwrap()
.1
}
pub fn get_mut(&mut self) -> &mut V {
&mut self.map.buckets[self.bucket_pos as usize]
.inner
.as_mut()
.unwrap()
.1
}
pub fn insert(&mut self, value: V) -> V {
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// This assumes inner is Some, which it must be for an OccupiedEntry
let old_value = mem::replace(&mut bucket.inner.as_mut().unwrap().1, value);
old_value
}
pub fn remove(self) -> V {
// CoreHashMap::remove returns Option<(K, V)>. We know it's Some for an OccupiedEntry.
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// unlink it from the chain
match self.prev_pos {
PrevPos::First(dict_pos) => self.map.dictionary[dict_pos as usize] = bucket.next,
PrevPos::Chained(bucket_pos) => {
self.map.buckets[bucket_pos as usize].next = bucket.next
}
}
// and add it to the freelist
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
let old_value = bucket.inner.take();
bucket.next = self.map.free_head;
self.map.free_head = self.bucket_pos;
self.map.buckets_in_use -= 1;
return old_value.unwrap().1;
}
}
pub struct VacantEntry<'a, 'b, K, V> {
pub(crate) map: &'b mut CoreHashMap<'a, K, V>,
pub(crate) key: K, // The key to insert
pub(crate) dict_pos: u32,
}
impl<'a, 'b, K: Clone + Hash + Eq, V> VacantEntry<'a, 'b, K, V> {
pub fn insert(self, value: V) -> Result<&'b mut V, FullError> {
let pos = self.map.alloc_bucket(self.key, value)?;
if pos == INVALID_POS {
return Err(FullError());
}
let bucket = &mut self.map.buckets[pos as usize];
bucket.next = self.map.dictionary[self.dict_pos as usize];
self.map.dictionary[self.dict_pos as usize] = pos;
let result = &mut self.map.buckets[pos as usize].inner.as_mut().unwrap().1;
return Ok(result);
}
}

View File

@@ -1,220 +0,0 @@
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::fmt::{Debug, Formatter};
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::hash::HashMapAccess;
use crate::hash::HashMapInit;
use crate::hash::UpdateAction;
use crate::shmem::ShmemHandle;
use rand::seq::SliceRandom;
use rand::{Rng, RngCore};
use rand_distr::Zipf;
const TEST_KEY_LEN: usize = 16;
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct TestKey([u8; TEST_KEY_LEN]);
impl From<&TestKey> for u128 {
fn from(val: &TestKey) -> u128 {
u128::from_be_bytes(val.0)
}
}
impl From<u128> for TestKey {
fn from(val: u128) -> TestKey {
TestKey(val.to_be_bytes())
}
}
impl<'a> From<&'a [u8]> for TestKey {
fn from(bytes: &'a [u8]) -> TestKey {
TestKey(bytes.try_into().unwrap())
}
}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
const MAX_MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_inserts", 0, MAX_MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(100000, shmem);
let w = init_struct.attach_writer();
for (idx, k) in keys.iter().enumerate() {
let res = w.insert(&(*k).into(), idx);
assert!(res.is_ok());
}
for (idx, k) in keys.iter().enumerate() {
let x = w.get(&(*k).into());
let value = x.as_deref().copied();
assert_eq!(value, Some(idx));
}
//eprintln!("stats: {:?}", tree_writer.get_statistics());
}
#[test]
fn dense() {
// This exercises splitting a node with prefix
let keys: &[u128] = &[0, 1, 2, 3, 256];
test_inserts(keys);
// Dense keys
let mut keys: Vec<u128> = (0..10000).collect();
test_inserts(&keys);
// Do the same in random orders
for _ in 1..10 {
keys.shuffle(&mut rand::rng());
test_inserts(&keys);
}
}
#[test]
fn sparse() {
// sparse keys
let mut keys: Vec<TestKey> = Vec::new();
let mut used_keys = HashSet::new();
for _ in 0..10000 {
loop {
let key = rand::random::<u128>();
if used_keys.get(&key).is_some() {
continue;
}
used_keys.insert(key);
keys.push(key.into());
break;
}
}
test_inserts(&keys);
}
struct TestValue(AtomicUsize);
impl TestValue {
fn new(val: usize) -> TestValue {
TestValue(AtomicUsize::new(val))
}
fn load(&self) -> usize {
self.0.load(Ordering::Relaxed)
}
}
impl Clone for TestValue {
fn clone(&self) -> TestValue {
TestValue::new(self.load())
}
}
impl Debug for TestValue {
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(fmt, "{:?}", self.load())
}
}
#[derive(Clone, Debug)]
struct TestOp(TestKey, Option<usize>);
fn apply_op(
op: &TestOp,
sut: &HashMapAccess<TestKey, TestValue>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
eprintln!("applying op: {op:?}");
// apply the change to the shadow tree first
let shadow_existing = if let Some(v) = op.1 {
shadow.insert(op.0, v)
} else {
shadow.remove(&op.0)
};
// apply to Art tree
sut.update_with_fn(&op.0, |existing| {
assert_eq!(existing.map(TestValue::load), shadow_existing);
match (existing, op.1) {
(None, None) => UpdateAction::Nothing,
(None, Some(new_val)) => UpdateAction::Insert(TestValue::new(new_val)),
(Some(_old_val), None) => UpdateAction::Remove,
(Some(old_val), Some(new_val)) => {
old_val.0.store(new_val, Ordering::Relaxed);
UpdateAction::Nothing
}
}
})
.expect("out of memory");
}
#[test]
fn random_ops() {
const MAX_MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_inserts", 0, MAX_MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, TestValue>::init_in_shmem(100000, shmem);
let writer = init_struct.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
let mut rng = rand::rng();
for i in 0..100000 {
let key: TestKey = (rng.sample(distribution) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
//eprintln!("stats: {:?}", tree_writer.get_statistics());
//test_iter(&tree_writer, &shadow);
}
}
}
#[test]
fn test_grow() {
const MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_grow", 0, MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, TestValue>::init_in_shmem(1000, shmem);
let writer = init_struct.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
for i in 0..10000 {
let key: TestKey = ((rng.next_u32() % 1000) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
//eprintln!("stats: {:?}", tree_writer.get_statistics());
//test_iter(&tree_writer, &shadow);
}
}
writer.grow(1500).unwrap();
for i in 0..10000 {
let key: TestKey = ((rng.next_u32() % 1500) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
//eprintln!("stats: {:?}", tree_writer.get_statistics());
//test_iter(&tree_writer, &shadow);
}
}
}

View File

@@ -1,4 +1,418 @@
//! Shared memory utilities for neon communicator
pub mod hash;
pub mod shmem;
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {} too large", max_size);
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({}) is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry.
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create() to create an
/// anonymous in-memory file. On macOS, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {}", i);
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

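A minimal usage sketch of the ShmemHandle above (standalone, with error handling elided; the crate path and root re-export are assumptions). The handle reserves address space for max_size once, and set_size() only resizes the backing memfd, guarded by the RESIZE_IN_PROGRESS bit so concurrent resizes fail instead of racing:

use neon_shmem::ShmemHandle;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Reserve 1 MiB of address space; the backing file starts empty.
    let shmem = ShmemHandle::new("example-area", 0, 1024 * 1024)?;

    // Grow the backing file before touching the memory.
    shmem.set_size(4096)?;
    assert_eq!(shmem.current_size(), 4096);

    // User data starts at data_ptr; stay within current_size().
    unsafe { std::ptr::write_bytes(shmem.data_ptr.as_ptr(), 0xAB, 4096) };

    // Shrinking truncates the file; if the area is grown again later,
    // the previously truncated pages read back as zeroes.
    shmem.set_size(1024)?;
    Ok(())
}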
View File

@@ -1,418 +0,0 @@
//! Dynamically resizable contiguous chunk of shared memory
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {} too large", max_size);
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({}) is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry.
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create() to create an
/// anonymous in-memory file. On macOS, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {}", i);
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
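A small hypothetical sketch for illustration: because the barrier is just an atomic counter plus polling, it also works between plain threads; the tests below place it in the shared mapping instead, so it synchronizes fork()ed processes.

    // Illustration only: two threads rendezvous on a SimpleBarrier in ordinary memory.
    fn barrier_demo() {
        let barrier = SimpleBarrier {
            num_procs: 2,
            count: AtomicUsize::new(0),
        };
        std::thread::scope(|s| {
            s.spawn(|| barrier.wait());
            barrier.wait(); // returns once both participants have arrived
        });
    }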
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -20,6 +20,7 @@ use postgres_backend::AuthType;
use remote_storage::RemoteStorageConfig;
use serde_with::serde_as;
use utils::logging::LogFormat;
use utils::postgres_client::PostgresClientProtocol;
use crate::models::{ImageCompressionAlgorithm, LsnLease};
@@ -180,7 +181,6 @@ pub struct ConfigToml {
pub virtual_file_io_engine: Option<crate::models::virtual_file::IoEngineKind>,
pub ingest_batch_size: u64,
pub max_vectored_read_bytes: MaxVectoredReadBytes,
pub max_get_vectored_keys: MaxGetVectoredKeys,
pub image_compression: ImageCompressionAlgorithm,
pub timeline_offloading: bool,
pub ephemeral_bytes_per_memory_kb: usize,
@@ -188,6 +188,7 @@ pub struct ConfigToml {
pub virtual_file_io_mode: Option<crate::models::virtual_file::IoMode>,
#[serde(skip_serializing_if = "Option::is_none")]
pub no_sync: Option<bool>,
pub wal_receiver_protocol: PostgresClientProtocol,
pub page_service_pipelining: PageServicePipeliningConfig,
pub get_vectored_concurrent_io: GetVectoredConcurrentIo,
pub enable_read_path_debugging: Option<bool>,
@@ -228,7 +229,7 @@ pub enum PageServicePipeliningConfig {
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PageServicePipeliningConfigPipelined {
    /// Fails config parsing and validation if larger than `max_get_vectored_keys`.
/// Causes runtime errors if larger than max get_vectored batch size.
pub max_batch_size: NonZeroUsize,
pub execution: PageServiceProtocolPipelinedExecutionStrategy,
// The default below is such that new versions of the software can start
@@ -328,8 +329,6 @@ pub struct TimelineImportConfig {
pub import_job_concurrency: NonZeroUsize,
pub import_job_soft_size_limit: NonZeroUsize,
pub import_job_checkpoint_threshold: NonZeroUsize,
/// Max size of the remote storage partial read done by any job
pub import_job_max_byte_range_size: NonZeroUsize,
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@@ -404,16 +403,6 @@ impl Default for EvictionOrder {
#[serde(transparent)]
pub struct MaxVectoredReadBytes(pub NonZeroUsize);
#[derive(Copy, Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(transparent)]
pub struct MaxGetVectoredKeys(NonZeroUsize);
impl MaxGetVectoredKeys {
pub fn get(&self) -> usize {
self.0.get()
}
}
/// Tenant-level configuration values, used for various purposes.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(default)]
@@ -525,6 +514,8 @@ pub struct TenantConfigToml {
/// (either this flag or the pageserver-global one need to be set)
pub timeline_offloading: bool,
pub wal_receiver_protocol_override: Option<PostgresClientProtocol>,
/// Enable rel_size_v2 for this tenant. Once enabled, the tenant will persist this information into
/// `index_part.json`, and it cannot be reversed.
pub rel_size_v2_enabled: bool,
@@ -596,8 +587,6 @@ pub mod defaults {
/// That is, slightly above 128 kB.
pub const DEFAULT_MAX_VECTORED_READ_BYTES: usize = 130 * 1024; // 130 KiB
pub const DEFAULT_MAX_GET_VECTORED_KEYS: usize = 32;
pub const DEFAULT_IMAGE_COMPRESSION: ImageCompressionAlgorithm =
ImageCompressionAlgorithm::Zstd { level: Some(1) };
@@ -605,6 +594,9 @@ pub mod defaults {
pub const DEFAULT_IO_BUFFER_ALIGNMENT: usize = 512;
pub const DEFAULT_WAL_RECEIVER_PROTOCOL: utils::postgres_client::PostgresClientProtocol =
utils::postgres_client::PostgresClientProtocol::Vanilla;
pub const DEFAULT_SSL_KEY_FILE: &str = "server.key";
pub const DEFAULT_SSL_CERT_FILE: &str = "server.crt";
}
@@ -693,9 +685,6 @@ impl Default for ConfigToml {
max_vectored_read_bytes: (MaxVectoredReadBytes(
NonZeroUsize::new(DEFAULT_MAX_VECTORED_READ_BYTES).unwrap(),
)),
max_get_vectored_keys: (MaxGetVectoredKeys(
NonZeroUsize::new(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap(),
)),
image_compression: (DEFAULT_IMAGE_COMPRESSION),
timeline_offloading: true,
ephemeral_bytes_per_memory_kb: (DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB),
@@ -703,6 +692,7 @@ impl Default for ConfigToml {
virtual_file_io_mode: None,
tenant_config: TenantConfigToml::default(),
no_sync: None,
wal_receiver_protocol: DEFAULT_WAL_RECEIVER_PROTOCOL,
page_service_pipelining: PageServicePipeliningConfig::Pipelined(
PageServicePipeliningConfigPipelined {
max_batch_size: NonZeroUsize::new(32).unwrap(),
@@ -723,10 +713,9 @@ impl Default for ConfigToml {
enable_tls_page_service_api: false,
dev_mode: false,
timeline_import_config: TimelineImportConfig {
import_job_concurrency: NonZeroUsize::new(32).unwrap(),
import_job_soft_size_limit: NonZeroUsize::new(256 * 1024 * 1024).unwrap(),
import_job_checkpoint_threshold: NonZeroUsize::new(32).unwrap(),
import_job_max_byte_range_size: NonZeroUsize::new(4 * 1024 * 1024).unwrap(),
import_job_concurrency: NonZeroUsize::new(128).unwrap(),
import_job_soft_size_limit: NonZeroUsize::new(1024 * 1024 * 1024).unwrap(),
import_job_checkpoint_threshold: NonZeroUsize::new(128).unwrap(),
},
basebackup_cache_config: None,
posthog_config: None,
@@ -847,6 +836,7 @@ impl Default for TenantConfigToml {
lsn_lease_length: LsnLease::DEFAULT_LENGTH,
lsn_lease_length_for_ts: LsnLease::DEFAULT_LENGTH_FOR_TS,
timeline_offloading: true,
wal_receiver_protocol_override: None,
rel_size_v2_enabled: false,
gc_compaction_enabled: DEFAULT_GC_COMPACTION_ENABLED,
gc_compaction_verification: DEFAULT_GC_COMPACTION_VERIFICATION,

View File

@@ -20,6 +20,7 @@ use serde_with::serde_as;
pub use utilization::PageserverUtilization;
use utils::id::{NodeId, TenantId, TimelineId};
use utils::lsn::Lsn;
use utils::postgres_client::PostgresClientProtocol;
use utils::{completion, serde_system_time};
use crate::config::Ratio;
@@ -621,6 +622,8 @@ pub struct TenantConfigPatch {
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub timeline_offloading: FieldPatch<bool>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub wal_receiver_protocol_override: FieldPatch<PostgresClientProtocol>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub rel_size_v2_enabled: FieldPatch<bool>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub gc_compaction_enabled: FieldPatch<bool>,
@@ -745,6 +748,9 @@ pub struct TenantConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub timeline_offloading: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub wal_receiver_protocol_override: Option<PostgresClientProtocol>,
#[serde(skip_serializing_if = "Option::is_none")]
pub rel_size_v2_enabled: Option<bool>,
@@ -806,6 +812,7 @@ impl TenantConfig {
mut lsn_lease_length,
mut lsn_lease_length_for_ts,
mut timeline_offloading,
mut wal_receiver_protocol_override,
mut rel_size_v2_enabled,
mut gc_compaction_enabled,
mut gc_compaction_verification,
@@ -898,6 +905,9 @@ impl TenantConfig {
.map(|v| humantime::parse_duration(&v))?
.apply(&mut lsn_lease_length_for_ts);
patch.timeline_offloading.apply(&mut timeline_offloading);
patch
.wal_receiver_protocol_override
.apply(&mut wal_receiver_protocol_override);
patch.rel_size_v2_enabled.apply(&mut rel_size_v2_enabled);
patch
.gc_compaction_enabled
@@ -950,6 +960,7 @@ impl TenantConfig {
lsn_lease_length,
lsn_lease_length_for_ts,
timeline_offloading,
wal_receiver_protocol_override,
rel_size_v2_enabled,
gc_compaction_enabled,
gc_compaction_verification,
@@ -1047,6 +1058,9 @@ impl TenantConfig {
timeline_offloading: self
.timeline_offloading
.unwrap_or(global_conf.timeline_offloading),
wal_receiver_protocol_override: self
.wal_receiver_protocol_override
.or(global_conf.wal_receiver_protocol_override),
rel_size_v2_enabled: self
.rel_size_v2_enabled
.unwrap_or(global_conf.rel_size_v2_enabled),
@@ -1920,7 +1934,7 @@ pub enum PagestreamFeMessage {
}
// Wrapped in libpq CopyData
#[derive(Debug, strum_macros::EnumProperty)]
#[derive(strum_macros::EnumProperty)]
pub enum PagestreamBeMessage {
Exists(PagestreamExistsResponse),
Nblocks(PagestreamNblocksResponse),
@@ -2031,7 +2045,7 @@ pub enum PagestreamProtocolVersion {
pub type RequestId = u64;
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct PagestreamRequest {
pub reqid: RequestId,
pub request_lsn: Lsn,
@@ -2050,7 +2064,7 @@ pub struct PagestreamNblocksRequest {
pub rel: RelTag,
}
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct PagestreamGetPageRequest {
pub hdr: PagestreamRequest,
pub rel: RelTag,

View File

@@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize};
// FIXME: should move 'forknum' as last field to keep this consistent with Postgres.
// Then we could replace the custom Ord and PartialOrd implementations below with
// deriving them. This will require changes in walredoproc.c.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct RelTag {
pub forknum: u8,
pub spcnode: Oid,
@@ -184,12 +184,12 @@ pub enum SlruKind {
MultiXactOffsets,
}
impl fmt::Display for SlruKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
impl SlruKind {
pub fn to_str(&self) -> &'static str {
match self {
Self::Clog => write!(f, "pg_xact"),
Self::MultiXactMembers => write!(f, "pg_multixact/members"),
Self::MultiXactOffsets => write!(f, "pg_multixact/offsets"),
Self::Clog => "pg_xact",
Self::MultiXactMembers => "pg_multixact/members",
Self::MultiXactOffsets => "pg_multixact/offsets",
}
}
}

View File

@@ -4,9 +4,8 @@ use std::{sync::Arc, time::Duration};
use arc_swap::ArcSwap;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, info_span};
use crate::{CaptureEvent, FeatureStore, PostHogClient, PostHogClientConfig};
use crate::{FeatureStore, PostHogClient, PostHogClientConfig};
/// A background loop that fetches feature flags from PostHog and updates the feature store.
pub struct FeatureResolverBackgroundLoop {
@@ -24,61 +23,34 @@ impl FeatureResolverBackgroundLoop {
}
}
pub fn spawn(
self: Arc<Self>,
handle: &tokio::runtime::Handle,
refresh_period: Duration,
fake_tenants: Vec<CaptureEvent>,
) {
pub fn spawn(self: Arc<Self>, handle: &tokio::runtime::Handle, refresh_period: Duration) {
let this = self.clone();
let cancel = self.cancel.clone();
// Main loop of updating the feature flags.
handle.spawn(
async move {
tracing::info!("Starting PostHog feature resolver");
let mut ticker = tokio::time::interval(refresh_period);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
_ = ticker.tick() => {}
_ = cancel.cancelled() => break
handle.spawn(async move {
tracing::info!("Starting PostHog feature resolver");
let mut ticker = tokio::time::interval(refresh_period);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
_ = ticker.tick() => {}
_ = cancel.cancelled() => break
}
let resp = match this
.posthog_client
.get_feature_flags_local_evaluation()
.await
{
Ok(resp) => resp,
Err(e) => {
tracing::warn!("Cannot get feature flags: {}", e);
continue;
}
let resp = match this
.posthog_client
.get_feature_flags_local_evaluation()
.await
{
Ok(resp) => resp,
Err(e) => {
tracing::warn!("Cannot get feature flags: {}", e);
continue;
}
};
let feature_store = FeatureStore::new_with_flags(resp.flags);
this.feature_store.store(Arc::new(feature_store));
tracing::info!("Feature flag updated");
}
tracing::info!("PostHog feature resolver stopped");
};
let feature_store = FeatureStore::new_with_flags(resp.flags);
this.feature_store.store(Arc::new(feature_store));
}
.instrument(info_span!("posthog_feature_resolver")),
);
// Report fake tenants to PostHog so that we have the combination of all the properties in the UI.
// Do one report per pageserver restart.
let this = self.clone();
handle.spawn(
async move {
tracing::info!("Starting PostHog feature reporter");
for tenant in &fake_tenants {
tracing::info!("Reporting fake tenant: {:?}", tenant);
}
if let Err(e) = this.posthog_client.capture_event_batch(&fake_tenants).await {
tracing::warn!("Cannot report fake tenants: {}", e);
}
}
.instrument(info_span!("posthog_feature_reporter")),
);
tracing::info!("PostHog feature resolver stopped");
});
}
pub fn feature_store(&self) -> Arc<FeatureStore> {

View File

@@ -22,16 +22,6 @@ pub enum PostHogEvaluationError {
Internal(String),
}
impl PostHogEvaluationError {
pub fn as_variant_str(&self) -> &'static str {
match self {
PostHogEvaluationError::NotAvailable(_) => "not_available",
PostHogEvaluationError::NoConditionGroupMatched => "no_condition_group_matched",
PostHogEvaluationError::Internal(_) => "internal",
}
}
}
#[derive(Deserialize)]
pub struct LocalEvaluationResponse {
pub flags: Vec<LocalEvaluationFlag>,
@@ -64,7 +54,7 @@ pub struct LocalEvaluationFlagFilterProperty {
operator: String,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PostHogFlagFilterPropertyValue {
String(String),
@@ -458,18 +448,6 @@ impl FeatureStore {
)))
}
}
/// Infer whether a feature flag is a boolean flag by checking if it has a multivariate filter.
pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result<bool, PostHogEvaluationError> {
if let Some(flag_config) = self.flags.get(flag_key) {
Ok(flag_config.filters.multivariate.is_none())
} else {
Err(PostHogEvaluationError::NotAvailable(format!(
"Not found in the local evaluation spec: {}",
flag_key
)))
}
}
}
pub struct PostHogClientConfig {
@@ -507,13 +485,6 @@ pub struct PostHogClient {
client: reqwest::Client,
}
#[derive(Serialize, Debug)]
pub struct CaptureEvent {
pub event: String,
pub distinct_id: String,
pub properties: serde_json::Value,
}
impl PostHogClient {
pub fn new(config: PostHogClientConfig) -> Self {
let client = reqwest::Client::new();
@@ -557,15 +528,7 @@ impl PostHogClient {
.bearer_auth(&self.config.server_api_key)
.send()
.await?;
let status = response.status();
let body = response.text().await?;
if !status.is_success() {
return Err(anyhow::anyhow!(
"Failed to get feature flags: {}, {}",
status,
body
));
}
Ok(serde_json::from_str(&body)?)
}
@@ -577,12 +540,12 @@ impl PostHogClient {
&self,
event: &str,
distinct_id: &str,
properties: &serde_json::Value,
properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
) -> anyhow::Result<()> {
// PUBLIC_URL/capture/
// with bearer token of self.client_api_key
let url = format!("{}/capture/", self.config.public_api_url);
let response = self
.client
self.client
.post(url)
.body(serde_json::to_string(&json!({
"api_key": self.config.client_api_key,
@@ -592,39 +555,6 @@ impl PostHogClient {
}))?)
.send()
.await?;
let status = response.status();
let body = response.text().await?;
if !status.is_success() {
return Err(anyhow::anyhow!(
"Failed to capture events: {}, {}",
status,
body
));
}
Ok(())
}
pub async fn capture_event_batch(&self, events: &[CaptureEvent]) -> anyhow::Result<()> {
// PUBLIC_URL/batch/
let url = format!("{}/batch/", self.config.public_api_url);
let response = self
.client
.post(url)
.body(serde_json::to_string(&json!({
"api_key": self.config.client_api_key,
"batch": events,
}))?)
.send()
.await?;
let status = response.status();
let body = response.text().await?;
if !status.is_success() {
return Err(anyhow::anyhow!(
"Failed to capture events: {}, {}",
status,
body
));
}
Ok(())
}
}

View File

@@ -28,7 +28,6 @@ use std::time::Duration;
use tokio::sync::Notify;
use tokio::time::Instant;
#[derive(Clone, Copy)]
pub struct LeakyBucketConfig {
/// This is the "time cost" of a single request unit.
/// Should loosely represent how long it takes to handle a request unit in active resource time.

View File

@@ -73,7 +73,6 @@ pub mod error;
/// async timeout helper
pub mod timeout;
pub mod span;
pub mod sync;
pub mod failpoint_support;

View File

@@ -1,19 +0,0 @@
//! Tracing span helpers.
/// Records the given fields in the current span, as a single call. The fields must already have
/// been declared for the span (typically with empty values).
#[macro_export]
macro_rules! span_record {
($($tokens:tt)*) => {$crate::span_record_in!(::tracing::Span::current(), $($tokens)*)};
}
/// Records the given fields in the given span, as a single call. The fields must already have been
/// declared for the span (typically with empty values).
#[macro_export]
macro_rules! span_record_in {
($span:expr, $($tokens:tt)*) => {
if let Some(meta) = $span.metadata() {
$span.record_all(&tracing::valueset!(meta.fields(), $($tokens)*));
}
};
}

View File

@@ -439,7 +439,6 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState {
currentClusterSize: crate::bindings::pg_atomic_uint64 { value: 0 },
shard_ps_feedback: [empty_feedback; 128],
num_shards: 0,
replica_promote: false,
min_ps_feedback: empty_feedback,
}
}

View File

@@ -34,7 +34,6 @@ fail.workspace = true
futures.workspace = true
hashlink.workspace = true
hex.workspace = true
http.workspace = true
http-utils.workspace = true
humantime-serde.workspace = true
humantime.workspace = true
@@ -94,7 +93,6 @@ tokio-util.workspace = true
toml_edit = { workspace = true, features = [ "serde" ] }
tonic.workspace = true
tonic-reflection.workspace = true
tower.workspace = true
tracing.workspace = true
tracing-utils.workspace = true
url.workspace = true

View File

@@ -264,56 +264,10 @@ mod propagation_of_cached_label_value {
}
}
criterion_group!(histograms, histograms::bench_bucket_scalability);
mod histograms {
use std::time::Instant;
use criterion::{BenchmarkId, Criterion};
use metrics::core::Collector;
pub fn bench_bucket_scalability(c: &mut Criterion) {
let mut g = c.benchmark_group("bucket_scalability");
for n in [1, 4, 8, 16, 32, 64, 128, 256] {
g.bench_with_input(BenchmarkId::new("nbuckets", n), &n, |b, n| {
b.iter_custom(|iters| {
let buckets: Vec<f64> = (0..*n).map(|i| i as f64 * 100.0).collect();
let histo = metrics::Histogram::with_opts(
metrics::prometheus::HistogramOpts::new("name", "help")
.buckets(buckets.clone()),
)
.unwrap();
let start = Instant::now();
for i in 0..usize::try_from(iters).unwrap() {
histo.observe(buckets[i % buckets.len()]);
}
let elapsed = start.elapsed();
// self-test
let mfs = histo.collect();
assert_eq!(mfs.len(), 1);
let metrics = mfs[0].get_metric();
assert_eq!(metrics.len(), 1);
let histo = metrics[0].get_histogram();
let buckets = histo.get_bucket();
assert!(
buckets
.iter()
.enumerate()
.all(|(i, b)| b.get_cumulative_count()
>= i as u64 * (iters / buckets.len() as u64))
);
elapsed
})
});
}
}
}
criterion_main!(
label_values,
single_metric_multicore_scalability,
propagation_of_cached_label_value,
histograms,
propagation_of_cached_label_value
);
/*
@@ -336,14 +290,6 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [211.50 ns 214.44 ns
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [14.135 ns 14.147 ns 14.160 ns]
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [14.243 ns 14.255 ns 14.268 ns]
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [14.470 ns 14.682 ns 14.895 ns]
bucket_scalability/nbuckets/1 time: [30.352 ns 30.353 ns 30.354 ns]
bucket_scalability/nbuckets/4 time: [30.464 ns 30.465 ns 30.467 ns]
bucket_scalability/nbuckets/8 time: [30.569 ns 30.575 ns 30.584 ns]
bucket_scalability/nbuckets/16 time: [30.961 ns 30.965 ns 30.969 ns]
bucket_scalability/nbuckets/32 time: [35.691 ns 35.707 ns 35.722 ns]
bucket_scalability/nbuckets/64 time: [47.829 ns 47.898 ns 47.974 ns]
bucket_scalability/nbuckets/128 time: [73.479 ns 73.512 ns 73.545 ns]
bucket_scalability/nbuckets/256 time: [127.92 ns 127.94 ns 127.96 ns]
Results on an i3en.3xlarge instance
@@ -398,14 +344,6 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [434.87 ns 456.4
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [3.3767 ns 3.3974 ns 3.4220 ns]
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [3.6105 ns 4.2355 ns 5.1463 ns]
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [4.0889 ns 4.9714 ns 6.0779 ns]
bucket_scalability/nbuckets/1 time: [4.8455 ns 4.8542 ns 4.8646 ns]
bucket_scalability/nbuckets/4 time: [4.5663 ns 4.5722 ns 4.5787 ns]
bucket_scalability/nbuckets/8 time: [4.5531 ns 4.5670 ns 4.5842 ns]
bucket_scalability/nbuckets/16 time: [4.6392 ns 4.6524 ns 4.6685 ns]
bucket_scalability/nbuckets/32 time: [6.0302 ns 6.0439 ns 6.0589 ns]
bucket_scalability/nbuckets/64 time: [10.608 ns 10.644 ns 10.691 ns]
bucket_scalability/nbuckets/128 time: [22.178 ns 22.316 ns 22.483 ns]
bucket_scalability/nbuckets/256 time: [42.190 ns 42.328 ns 42.492 ns]
Results on a Hetzner AX102 AMD Ryzen 9 7950X3D 16-Core Processor
@@ -424,13 +362,5 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [164.24 ns 170.1
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [2.2915 ns 2.2960 ns 2.3012 ns]
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [2.5726 ns 2.6158 ns 2.6624 ns]
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [2.7068 ns 2.8243 ns 2.9824 ns]
bucket_scalability/nbuckets/1 time: [6.3998 ns 6.4288 ns 6.4684 ns]
bucket_scalability/nbuckets/4 time: [6.3603 ns 6.3620 ns 6.3637 ns]
bucket_scalability/nbuckets/8 time: [6.1646 ns 6.1654 ns 6.1667 ns]
bucket_scalability/nbuckets/16 time: [6.1341 ns 6.1391 ns 6.1454 ns]
bucket_scalability/nbuckets/32 time: [8.2206 ns 8.2254 ns 8.2301 ns]
bucket_scalability/nbuckets/64 time: [13.988 ns 13.994 ns 14.000 ns]
bucket_scalability/nbuckets/128 time: [28.180 ns 28.216 ns 28.251 ns]
bucket_scalability/nbuckets/256 time: [54.914 ns 54.931 ns 54.951 ns]
*/

View File

@@ -9,6 +9,7 @@ bytes.workspace = true
pageserver_api.workspace = true
postgres_ffi.workspace = true
prost.workspace = true
smallvec.workspace = true
thiserror.workspace = true
tonic.workspace = true
utils.workspace = true

View File

@@ -9,16 +9,10 @@
//! - Use more precise datatypes, e.g. Lsn and uints shorter than 32 bits.
//!
//! - Validate protocol invariants, via try_from() and try_into().
//!
//! Validation only happens on the receiver side, i.e. when converting from Protobuf to domain
//! types. This is where it matters -- the Protobuf types are less strict than the domain types, and
//! receivers should expect all sorts of junk from senders. This also allows the sender to use e.g.
//! stream combinators without dealing with errors, and avoids validating the same message twice.
use std::fmt::Display;
use bytes::Bytes;
use postgres_ffi::Oid;
use smallvec::SmallVec;
// TODO: split out Lsn, RelTag, SlruKind, Oid and other basic types to a separate crate, to avoid
// pulling in all of their other crate dependencies when building the client.
use utils::lsn::Lsn;
@@ -54,8 +48,7 @@ pub struct ReadLsn {
pub request_lsn: Lsn,
/// If given, the caller guarantees that the page has not been modified since this LSN. Must be
/// smaller than or equal to request_lsn. This allows the Pageserver to serve an old page
/// without waiting for the request LSN to arrive. If not given, the request will read at the
/// request_lsn and wait for it to arrive if necessary. Valid for all request types.
/// without waiting for the request LSN to arrive. Valid for all request types.
///
/// It is undefined behaviour to make a request such that the page was, in fact, modified
/// between request_lsn and not_modified_since_lsn. The Pageserver might detect it and return an
@@ -65,14 +58,19 @@ pub struct ReadLsn {
pub not_modified_since_lsn: Option<Lsn>,
}
impl Display for ReadLsn {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let req_lsn = self.request_lsn;
if let Some(mod_lsn) = self.not_modified_since_lsn {
write!(f, "{req_lsn}>={mod_lsn}")
} else {
req_lsn.fmt(f)
impl ReadLsn {
/// Validates the ReadLsn.
pub fn validate(&self) -> Result<(), ProtocolError> {
if self.request_lsn == Lsn::INVALID {
return Err(ProtocolError::invalid("request_lsn", self.request_lsn));
}
if self.not_modified_since_lsn > Some(self.request_lsn) {
return Err(ProtocolError::invalid(
"not_modified_since_lsn",
self.not_modified_since_lsn,
));
}
Ok(())
}
}
@@ -80,31 +78,27 @@ impl TryFrom<proto::ReadLsn> for ReadLsn {
type Error = ProtocolError;
fn try_from(pb: proto::ReadLsn) -> Result<Self, Self::Error> {
if pb.request_lsn == 0 {
return Err(ProtocolError::invalid("request_lsn", pb.request_lsn));
}
if pb.not_modified_since_lsn > pb.request_lsn {
return Err(ProtocolError::invalid(
"not_modified_since_lsn",
pb.not_modified_since_lsn,
));
}
Ok(Self {
let read_lsn = Self {
request_lsn: Lsn(pb.request_lsn),
not_modified_since_lsn: match pb.not_modified_since_lsn {
0 => None,
lsn => Some(Lsn(lsn)),
},
})
};
read_lsn.validate()?;
Ok(read_lsn)
}
}
impl From<ReadLsn> for proto::ReadLsn {
fn from(read_lsn: ReadLsn) -> Self {
Self {
impl TryFrom<ReadLsn> for proto::ReadLsn {
type Error = ProtocolError;
fn try_from(read_lsn: ReadLsn) -> Result<Self, Self::Error> {
read_lsn.validate()?;
Ok(Self {
request_lsn: read_lsn.request_lsn.0,
not_modified_since_lsn: read_lsn.not_modified_since_lsn.unwrap_or_default().0,
}
})
}
}
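A small sketch, assuming the types in this file, of what the sender-side TryFrom buys: converting a domain ReadLsn into its Protobuf form now runs the same validate() as the receive path, so an inconsistent LSN pair is rejected before it reaches the wire.

fn read_lsn_demo() -> Result<(), ProtocolError> {
    // A consistent pair converts fine; validate() runs inside try_into().
    let ok = ReadLsn {
        request_lsn: Lsn(0x200),
        not_modified_since_lsn: Some(Lsn(0x100)),
    };
    let pb: proto::ReadLsn = ok.try_into()?;
    assert_eq!(pb.not_modified_since_lsn, 0x100);

    // not_modified_since_lsn newer than request_lsn is rejected on the send side too.
    let bad = ReadLsn {
        request_lsn: Lsn(0x100),
        not_modified_since_lsn: Some(Lsn(0x200)),
    };
    assert!(proto::ReadLsn::try_from(bad).is_err());
    Ok(())
}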
@@ -159,15 +153,6 @@ impl TryFrom<proto::CheckRelExistsRequest> for CheckRelExistsRequest {
}
}
impl From<CheckRelExistsRequest> for proto::CheckRelExistsRequest {
fn from(request: CheckRelExistsRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
rel: Some(request.rel.into()),
}
}
}
pub type CheckRelExistsResponse = bool;
impl From<proto::CheckRelExistsResponse> for CheckRelExistsResponse {
@@ -205,12 +190,14 @@ impl TryFrom<proto::GetBaseBackupRequest> for GetBaseBackupRequest {
}
}
impl From<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
fn from(request: GetBaseBackupRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
impl TryFrom<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
type Error = ProtocolError;
fn try_from(request: GetBaseBackupRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
replica: request.replica,
}
})
}
}
@@ -227,9 +214,14 @@ impl TryFrom<proto::GetBaseBackupResponseChunk> for GetBaseBackupResponseChunk {
}
}
impl From<GetBaseBackupResponseChunk> for proto::GetBaseBackupResponseChunk {
fn from(chunk: GetBaseBackupResponseChunk) -> Self {
Self { chunk }
impl TryFrom<GetBaseBackupResponseChunk> for proto::GetBaseBackupResponseChunk {
type Error = ProtocolError;
fn try_from(chunk: GetBaseBackupResponseChunk) -> Result<Self, Self::Error> {
if chunk.is_empty() {
return Err(ProtocolError::Missing("chunk"));
}
Ok(Self { chunk })
}
}
@@ -254,12 +246,14 @@ impl TryFrom<proto::GetDbSizeRequest> for GetDbSizeRequest {
}
}
impl From<GetDbSizeRequest> for proto::GetDbSizeRequest {
fn from(request: GetDbSizeRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
impl TryFrom<GetDbSizeRequest> for proto::GetDbSizeRequest {
type Error = ProtocolError;
fn try_from(request: GetDbSizeRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
db_oid: request.db_oid,
}
})
}
}
@@ -294,7 +288,7 @@ pub struct GetPageRequest {
/// Multiple pages will be executed as a single batch by the Pageserver, amortizing layer access
/// costs and parallelizing them. This may increase the latency of any individual request, but
/// improves the overall latency and throughput of the batch as a whole.
pub block_numbers: Vec<u32>,
pub block_numbers: SmallVec<[u32; 1]>,
}
impl TryFrom<proto::GetPageRequest> for GetPageRequest {
@@ -312,20 +306,25 @@ impl TryFrom<proto::GetPageRequest> for GetPageRequest {
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
block_numbers: pb.block_number,
block_numbers: pb.block_number.into(),
})
}
}
impl From<GetPageRequest> for proto::GetPageRequest {
fn from(request: GetPageRequest) -> Self {
Self {
impl TryFrom<GetPageRequest> for proto::GetPageRequest {
type Error = ProtocolError;
fn try_from(request: GetPageRequest) -> Result<Self, Self::Error> {
if request.block_numbers.is_empty() {
return Err(ProtocolError::Missing("block_number"));
}
Ok(Self {
request_id: request.request_id,
request_class: request.request_class.into(),
read_lsn: Some(request.read_lsn.into()),
read_lsn: Some(request.read_lsn.try_into()?),
rel: Some(request.rel.into()),
block_number: request.block_numbers,
}
block_number: request.block_numbers.into_vec(),
})
}
}
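The Vec<u32> to SmallVec<[u32; 1]> switch targets the common single-page request: one block number stays inline on the stack and only multi-page batches allocate. A minimal illustration, independent of the request types above:

use smallvec::{SmallVec, smallvec};

fn smallvec_demo() {
    let single: SmallVec<[u32; 1]> = smallvec![42]; // stored inline, no heap allocation
    let batch: SmallVec<[u32; 1]> = smallvec![1, 2, 3]; // spills to the heap
    assert!(!single.spilled());
    assert!(batch.spilled());
}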
@@ -397,7 +396,7 @@ pub struct GetPageResponse {
/// A string describing the status, if any.
pub reason: Option<String>,
/// The 8KB page images, in the same order as the request. Empty if status != OK.
pub page_images: Vec<Bytes>,
pub page_images: SmallVec<[Bytes; 1]>,
}
impl From<proto::GetPageResponse> for GetPageResponse {
@@ -406,7 +405,7 @@ impl From<proto::GetPageResponse> for GetPageResponse {
request_id: pb.request_id,
status_code: pb.status_code.into(),
reason: Some(pb.reason).filter(|r| !r.is_empty()),
page_images: pb.page_image,
page_images: pb.page_image.into(),
}
}
}
@@ -417,7 +416,7 @@ impl From<GetPageResponse> for proto::GetPageResponse {
request_id: response.request_id,
status_code: response.status_code.into(),
reason: response.reason.unwrap_or_default(),
page_image: response.page_images,
page_image: response.page_images.into_vec(),
}
}
}
@@ -506,12 +505,14 @@ impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
}
}
impl From<GetRelSizeRequest> for proto::GetRelSizeRequest {
fn from(request: GetRelSizeRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
impl TryFrom<GetRelSizeRequest> for proto::GetRelSizeRequest {
type Error = ProtocolError;
fn try_from(request: GetRelSizeRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
rel: Some(request.rel.into()),
}
})
}
}
@@ -554,13 +555,15 @@ impl TryFrom<proto::GetSlruSegmentRequest> for GetSlruSegmentRequest {
}
}
impl From<GetSlruSegmentRequest> for proto::GetSlruSegmentRequest {
fn from(request: GetSlruSegmentRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
impl TryFrom<GetSlruSegmentRequest> for proto::GetSlruSegmentRequest {
type Error = ProtocolError;
fn try_from(request: GetSlruSegmentRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
kind: request.kind as u32,
segno: request.segno,
}
})
}
}
@@ -577,9 +580,14 @@ impl TryFrom<proto::GetSlruSegmentResponse> for GetSlruSegmentResponse {
}
}
impl From<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
fn from(segment: GetSlruSegmentResponse) -> Self {
Self { segment }
impl TryFrom<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
type Error = ProtocolError;
fn try_from(segment: GetSlruSegmentResponse) -> Result<Self, Self::Error> {
if segment.is_empty() {
return Err(ProtocolError::Missing("segment"));
}
Ok(Self { segment })
}
}

View File

@@ -8,8 +8,6 @@ license.workspace = true
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
bytes.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
@@ -17,17 +15,14 @@ hdrhistogram.workspace = true
humantime.workspace = true
humantime-serde.workspace = true
rand.workspace = true
reqwest.workspace = true
reqwest.workspace=true
serde.workspace = true
serde_json.workspace = true
tracing.workspace = true
tokio.workspace = true
tokio-stream.workspace = true
tokio-util.workspace = true
tonic.workspace = true
pageserver_client.workspace = true
pageserver_api.workspace = true
pageserver_page_api.workspace = true
utils = { path = "../../libs/utils/" }
workspace_hack = { version = "0.1", path = "../../workspace_hack" }

View File

@@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet, VecDeque};
use std::collections::{HashSet, VecDeque};
use std::future::Future;
use std::num::NonZeroUsize;
use std::pin::Pin;
@@ -7,15 +7,11 @@ use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use anyhow::Context;
use async_trait::async_trait;
use bytes::Bytes;
use camino::Utf8PathBuf;
use pageserver_api::key::Key;
use pageserver_api::keyspace::KeySpaceAccum;
use pageserver_api::models::{PagestreamGetPageRequest, PagestreamRequest};
use pageserver_api::reltag::RelTag;
use pageserver_api::shard::TenantShardId;
use pageserver_page_api::proto;
use rand::prelude::*;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -26,12 +22,6 @@ use utils::lsn::Lsn;
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
use crate::util::{request_stats, tokio_thread_local_stats};
#[derive(clap::ValueEnum, Clone, Debug)]
enum Protocol {
Libpq,
Grpc,
}
/// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace.
#[derive(clap::Parser)]
pub(crate) struct Args {
@@ -45,8 +35,6 @@ pub(crate) struct Args {
num_clients: NonZeroUsize,
#[clap(long)]
runtime: Option<humantime::Duration>,
#[clap(long, value_enum, default_value = "libpq")]
protocol: Protocol,
/// Each client sends requests at the given rate.
///
/// If a request takes too long and we should be issuing a new request already,
@@ -77,16 +65,6 @@ pub(crate) struct Args {
#[clap(long, default_value = "1")]
queue_depth: NonZeroUsize,
/// Batch size of contiguous pages generated by each client. This is equivalent to how Postgres
/// will request page batches (e.g. prefetches or vectored reads). A batch counts as 1 RPS and
/// 1 queue depth.
///
/// The libpq protocol does not support client-side batching, and will submit batches as many
/// individual requests, in the hope that the server will batch them. Each batch still counts as
/// 1 RPS and 1 queue depth.
#[clap(long, default_value = "1")]
batch_size: NonZeroUsize,
#[clap(long)]
only_relnode: Option<u32>,
@@ -325,20 +303,7 @@ async fn main_impl(
.unwrap();
Box::pin(async move {
let client: Box<dyn Client> = match args.protocol {
Protocol::Libpq => Box::new(
LibpqClient::new(args.page_service_connstring.clone(), worker_id.timeline)
.await
.unwrap(),
),
Protocol::Grpc => Box::new(
GrpcClient::new(args.page_service_connstring.clone(), worker_id.timeline)
.await
.unwrap(),
),
};
run_worker(args, client, ss, cancel, rps_period, ranges, weights).await
client_libpq(args, worker_id, ss, cancel, rps_period, ranges, weights).await
})
};
@@ -390,28 +355,27 @@ async fn main_impl(
anyhow::Ok(())
}
async fn run_worker(
async fn client_libpq(
args: &Args,
mut client: Box<dyn Client>,
worker_id: WorkerId,
shared_state: Arc<SharedState>,
cancel: CancellationToken,
rps_period: Option<Duration>,
ranges: Vec<KeyRange>,
weights: rand::distributions::weighted::WeightedIndex<i128>,
) {
let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone())
.await
.unwrap();
let mut client = client
.pagestream(worker_id.timeline.tenant_id, worker_id.timeline.timeline_id)
.await
.unwrap();
shared_state.start_work_barrier.wait().await;
let client_start = Instant::now();
let mut ticks_processed = 0;
let mut req_id = 0;
let batch_size: usize = args.batch_size.into();
// Track inflight requests by request ID and start time. This times the request duration, and
// ensures responses match requests. We don't expect responses back in any particular order.
//
// NB: this does not check that all requests received a response, because we don't wait for the
// inflight requests to complete when the duration elapses.
let mut inflight: HashMap<u64, Instant> = HashMap::new();
let mut inflight = VecDeque::new();
while !cancel.is_cancelled() {
// Detect if a request took longer than the RPS rate
if let Some(period) = &rps_period {
@@ -427,72 +391,36 @@ async fn run_worker(
}
while inflight.len() < args.queue_depth.get() {
req_id += 1;
let start = Instant::now();
let (req_lsn, mod_lsn, rel, blks) = {
/// Converts a compact i128 key to a relation tag and block number.
fn key_to_block(key: i128) -> (RelTag, u32) {
let key = Key::from_i128(key);
assert!(key.is_rel_block_key());
key.to_rel_block()
.expect("we filter non-rel-block keys out above")
}
// Pick a random page from a random relation.
let req = {
let mut rng = rand::thread_rng();
let r = &ranges[weights.sample(&mut rng)];
let key: i128 = rng.gen_range(r.start..r.end);
let (rel_tag, block_no) = key_to_block(key);
let mut blks = VecDeque::with_capacity(batch_size);
blks.push_back(block_no);
// If requested, populate a batch of sequential pages. This is how Postgres will
// request page batches (e.g. prefetches). If we hit the end of the relation, we
// grow the batch towards the start too.
for i in 1..batch_size {
let (r, b) = key_to_block(key + i as i128);
if r != rel_tag {
break; // went outside relation
}
blks.push_back(b)
let key = Key::from_i128(key);
assert!(key.is_rel_block_key());
let (rel_tag, block_no) = key
.to_rel_block()
.expect("we filter non-rel-block keys out above");
PagestreamGetPageRequest {
hdr: PagestreamRequest {
reqid: 0,
request_lsn: if rng.gen_bool(args.req_latest_probability) {
Lsn::MAX
} else {
r.timeline_lsn
},
not_modified_since: r.timeline_lsn,
},
rel: rel_tag,
blkno: block_no,
}
if blks.len() < batch_size {
// Grow batch backwards if needed.
for i in 1..batch_size {
let (r, b) = key_to_block(key - i as i128);
if r != rel_tag {
break; // went outside relation
}
blks.push_front(b)
}
}
// We assume that the entire batch can fit within the relation.
assert_eq!(blks.len(), batch_size, "incomplete batch");
let req_lsn = if rng.gen_bool(args.req_latest_probability) {
Lsn::MAX
} else {
r.timeline_lsn
};
(req_lsn, r.timeline_lsn, rel_tag, blks.into())
};
client
.send_get_page(req_id, req_lsn, mod_lsn, rel, blks)
.await
.unwrap();
let old = inflight.insert(req_id, start);
assert!(old.is_none(), "duplicate request ID {req_id}");
client.getpage_send(req).await.unwrap();
inflight.push_back(start);
}
let (req_id, pages) = client.recv_get_page().await.unwrap();
assert_eq!(pages.len(), batch_size, "unexpected page count");
assert!(pages.iter().all(|p| !p.is_empty()), "empty page");
let start = inflight
.remove(&req_id)
.expect("response for unknown request ID");
let start = inflight.pop_front().unwrap();
client.getpage_recv().await.unwrap();
let end = Instant::now();
shared_state.live_stats.request_done();
ticks_processed += 1;
@@ -514,154 +442,3 @@ async fn run_worker(
}
}
}
/// A benchmark client, to allow switching out the transport protocol.
///
/// For simplicity, this just uses separate asynchronous send/recv methods. The send method could
/// return a future that resolves when the response is received, but we don't really need it.
#[async_trait]
trait Client: Send {
/// Sends an asynchronous GetPage request to the pageserver.
async fn send_get_page(
&mut self,
req_id: u64,
req_lsn: Lsn,
mod_lsn: Lsn,
rel: RelTag,
blks: Vec<u32>,
) -> anyhow::Result<()>;
/// Receives the next GetPage response from the pageserver.
async fn recv_get_page(&mut self) -> anyhow::Result<(u64, Vec<Bytes>)>;
}
/// A libpq-based Pageserver client.
struct LibpqClient {
inner: pageserver_client::page_service::PagestreamClient,
// Track sent batches, so we know how many responses to expect.
batch_sizes: VecDeque<usize>,
}
impl LibpqClient {
async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
let inner = pageserver_client::page_service::Client::new(connstring)
.await?
.pagestream(ttid.tenant_id, ttid.timeline_id)
.await?;
Ok(Self {
inner,
batch_sizes: VecDeque::new(),
})
}
}
#[async_trait]
impl Client for LibpqClient {
async fn send_get_page(
&mut self,
req_id: u64,
req_lsn: Lsn,
mod_lsn: Lsn,
rel: RelTag,
blks: Vec<u32>,
) -> anyhow::Result<()> {
// libpq doesn't support client-side batches, so we send a bunch of individual requests
// instead in the hope that the server will batch them for us. We use the same request ID
// for all, because we'll return a single batch response.
self.batch_sizes.push_back(blks.len());
for blkno in blks {
let req = PagestreamGetPageRequest {
hdr: PagestreamRequest {
reqid: req_id,
request_lsn: req_lsn,
not_modified_since: mod_lsn,
},
rel,
blkno,
};
self.inner.getpage_send(req).await?;
}
Ok(())
}
async fn recv_get_page(&mut self) -> anyhow::Result<(u64, Vec<Bytes>)> {
let batch_size = self.batch_sizes.pop_front().unwrap();
let mut batch = Vec::with_capacity(batch_size);
let mut req_id = None;
for _ in 0..batch_size {
let resp = self.inner.getpage_recv().await?;
if req_id.is_none() {
req_id = Some(resp.req.hdr.reqid);
}
assert_eq!(req_id, Some(resp.req.hdr.reqid), "request ID mismatch");
batch.push(resp.page);
}
Ok((req_id.unwrap(), batch))
}
}
/// A gRPC client using the raw, no-frills gRPC client.
struct GrpcClient {
req_tx: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
resp_rx: tonic::Streaming<proto::GetPageResponse>,
}
impl GrpcClient {
async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
let mut client = pageserver_page_api::proto::PageServiceClient::connect(connstring).await?;
// The channel has a buffer size of 1, since 0 is not allowed. It does not matter, since the
// benchmark will control the queue depth (i.e. in-flight requests) anyway, and requests are
// buffered by Tonic and the OS too.
let (req_tx, req_rx) = tokio::sync::mpsc::channel(1);
let req_stream = tokio_stream::wrappers::ReceiverStream::new(req_rx);
let mut req = tonic::Request::new(req_stream);
let metadata = req.metadata_mut();
metadata.insert("neon-tenant-id", ttid.tenant_id.to_string().try_into()?);
metadata.insert("neon-timeline-id", ttid.timeline_id.to_string().try_into()?);
metadata.insert("neon-shard-id", "0000".try_into()?);
let resp = client.get_pages(req).await?;
let resp_stream = resp.into_inner();
Ok(Self {
req_tx,
resp_rx: resp_stream,
})
}
}
#[async_trait]
impl Client for GrpcClient {
async fn send_get_page(
&mut self,
req_id: u64,
req_lsn: Lsn,
mod_lsn: Lsn,
rel: RelTag,
blks: Vec<u32>,
) -> anyhow::Result<()> {
let req = proto::GetPageRequest {
request_id: req_id,
request_class: proto::GetPageClass::Normal as i32,
read_lsn: Some(proto::ReadLsn {
request_lsn: req_lsn.0,
not_modified_since_lsn: mod_lsn.0,
}),
rel: Some(rel.into()),
block_number: blks,
};
self.req_tx.send(req).await?;
Ok(())
}
async fn recv_get_page(&mut self) -> anyhow::Result<(u64, Vec<Bytes>)> {
let resp = self.resp_rx.message().await?.unwrap();
anyhow::ensure!(
resp.status_code == proto::GetPageStatusCode::Ok as i32,
"unexpected status code: {}",
resp.status_code
);
Ok((resp.request_id, resp.page_image))
}
}

View File

@@ -65,30 +65,6 @@ impl From<GetVectoredError> for BasebackupError {
}
}
impl From<BasebackupError> for postgres_backend::QueryError {
fn from(err: BasebackupError) -> Self {
use postgres_backend::QueryError;
use pq_proto::framed::ConnectionError;
match err {
BasebackupError::Client(err, _) => QueryError::Disconnected(ConnectionError::Io(err)),
BasebackupError::Server(err) => QueryError::Other(err),
BasebackupError::Shutdown => QueryError::Shutdown,
}
}
}
impl From<BasebackupError> for tonic::Status {
fn from(err: BasebackupError) -> Self {
use tonic::Code;
let code = match &err {
BasebackupError::Client(_, _) => Code::Cancelled,
BasebackupError::Server(_) => Code::Internal,
BasebackupError::Shutdown => Code::Unavailable,
};
tonic::Status::new(code, err.to_string())
}
}
/// Create basebackup with non-rel data in it.
/// Only include relational data if 'full_backup' is true.
///
@@ -272,7 +248,7 @@ where
async fn flush(&mut self) -> Result<(), BasebackupError> {
let nblocks = self.buf.len() / BLCKSZ as usize;
let (kind, segno) = self.current_segment.take().unwrap();
let segname = format!("{kind}/{segno:>04X}");
let segname = format!("{}/{:>04X}", kind.to_str(), segno);
let header = new_tar_header(&segname, self.buf.len() as u64)?;
self.ar
.append(&header, self.buf.as_slice())
@@ -371,7 +347,7 @@ where
.await?
.partition(
self.timeline.get_shard_identity(),
self.timeline.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64,
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
);
let mut slru_builder = SlruSegmentsBuilder::new(&mut self.ar);

View File

@@ -158,6 +158,7 @@ fn main() -> anyhow::Result<()> {
// (maybe we should automate this with a visitor?).
info!(?conf.virtual_file_io_engine, "starting with virtual_file IO engine");
info!(?conf.virtual_file_io_mode, "starting with virtual_file IO mode");
info!(?conf.wal_receiver_protocol, "starting with WAL receiver protocol");
info!(?conf.validate_wal_contiguity, "starting with WAL contiguity validation");
info!(?conf.page_service_pipelining, "starting with page service pipelining config");
info!(?conf.get_vectored_concurrent_io, "starting with get_vectored IO concurrency config");
@@ -803,7 +804,7 @@ fn start_pageserver(
} else {
None
},
basebackup_cache,
basebackup_cache.clone(),
);
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for
@@ -815,11 +816,12 @@ fn start_pageserver(
let mut page_service_grpc = None;
if let Some(grpc_listener) = grpc_listener {
page_service_grpc = Some(page_service::spawn_grpc(
conf,
tenant_manager.clone(),
grpc_auth,
otel_guard.as_ref().map(|g| g.dispatch.clone()),
conf.get_vectored_concurrent_io,
grpc_listener,
basebackup_cache,
)?);
}

View File

@@ -14,10 +14,7 @@ use std::time::Duration;
use anyhow::{Context, bail, ensure};
use camino::{Utf8Path, Utf8PathBuf};
use once_cell::sync::OnceCell;
use pageserver_api::config::{
DiskUsageEvictionTaskConfig, MaxGetVectoredKeys, MaxVectoredReadBytes,
PageServicePipeliningConfig, PageServicePipeliningConfigPipelined, PostHogConfig,
};
use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes, PostHogConfig};
use pageserver_api::models::ImageCompressionAlgorithm;
use pageserver_api::shard::TenantShardId;
use pem::Pem;
@@ -27,6 +24,7 @@ use reqwest::Url;
use storage_broker::Uri;
use utils::id::{NodeId, TimelineId};
use utils::logging::{LogFormat, SecretString};
use utils::postgres_client::PostgresClientProtocol;
use crate::tenant::storage_layer::inmemory_layer::IndexEntry;
use crate::tenant::{TENANTS_SEGMENT_NAME, TIMELINES_SEGMENT_NAME};
@@ -187,9 +185,6 @@ pub struct PageServerConf {
pub max_vectored_read_bytes: MaxVectoredReadBytes,
/// Maximum number of keys to be read in a single get_vectored call.
pub max_get_vectored_keys: MaxGetVectoredKeys,
pub image_compression: ImageCompressionAlgorithm,
/// Whether to offload archived timelines automatically
@@ -210,6 +205,8 @@ pub struct PageServerConf {
/// Optionally disable disk syncs (unsafe!)
pub no_sync: bool,
pub wal_receiver_protocol: PostgresClientProtocol,
pub page_service_pipelining: pageserver_api::config::PageServicePipeliningConfig,
pub get_vectored_concurrent_io: pageserver_api::config::GetVectoredConcurrentIo,
@@ -407,7 +404,6 @@ impl PageServerConf {
secondary_download_concurrency,
ingest_batch_size,
max_vectored_read_bytes,
max_get_vectored_keys,
image_compression,
timeline_offloading,
ephemeral_bytes_per_memory_kb,
@@ -418,6 +414,7 @@ impl PageServerConf {
virtual_file_io_engine,
tenant_config,
no_sync,
wal_receiver_protocol,
page_service_pipelining,
get_vectored_concurrent_io,
enable_read_path_debugging,
@@ -473,13 +470,13 @@ impl PageServerConf {
secondary_download_concurrency,
ingest_batch_size,
max_vectored_read_bytes,
max_get_vectored_keys,
image_compression,
timeline_offloading,
ephemeral_bytes_per_memory_kb,
import_pgdata_upcall_api,
import_pgdata_upcall_api_token: import_pgdata_upcall_api_token.map(SecretString::from),
import_pgdata_aws_endpoint_url,
wal_receiver_protocol,
page_service_pipelining,
get_vectored_concurrent_io,
tracing,
@@ -601,19 +598,6 @@ impl PageServerConf {
)
})?;
if let PageServicePipeliningConfig::Pipelined(PageServicePipeliningConfigPipelined {
max_batch_size,
..
}) = conf.page_service_pipelining
{
if max_batch_size.get() > conf.max_get_vectored_keys.get() {
return Err(anyhow::anyhow!(
"`max_batch_size` ({max_batch_size}) must be less than or equal to `max_get_vectored_keys` ({})",
conf.max_get_vectored_keys.get()
));
}
};
Ok(conf)
}
@@ -701,7 +685,6 @@ impl ConfigurableSemaphore {
mod tests {
use camino::Utf8PathBuf;
use rstest::rstest;
use utils::id::NodeId;
use super::PageServerConf;
@@ -741,28 +724,4 @@ mod tests {
PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir)
.expect_err("parse_and_validate should fail for endpoint without scheme");
}
#[rstest]
#[case(32, 32, true)]
#[case(64, 32, false)]
#[case(64, 64, true)]
#[case(128, 128, true)]
fn test_config_max_batch_size_is_valid(
#[case] max_batch_size: usize,
#[case] max_get_vectored_keys: usize,
#[case] is_valid: bool,
) {
let input = format!(
r#"
control_plane_api = "http://localhost:6666"
max_get_vectored_keys = {max_get_vectored_keys}
page_service_pipelining = {{ mode="pipelined", execution="concurrent-futures", max_batch_size={max_batch_size}, batching="uniform-lsn" }}
"#,
);
let config_toml = toml_edit::de::from_str::<pageserver_api::config::ConfigToml>(&input)
.expect("config has valid fields");
let workdir = Utf8PathBuf::from("/nonexistent");
let result = PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir);
assert_eq!(result.is_ok(), is_valid);
}
}

View File

@@ -1,28 +1,21 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use posthog_client_lite::{
CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
PostHogFlagFilterPropertyValue,
FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
};
use remote_storage::RemoteStorageKind;
use serde_json::json;
use tokio_util::sync::CancellationToken;
use utils::id::TenantId;
use crate::{config::PageServerConf, metrics::FEATURE_FLAG_EVALUATION};
use crate::config::PageServerConf;
#[derive(Clone)]
pub struct FeatureResolver {
inner: Option<Arc<FeatureResolverBackgroundLoop>>,
internal_properties: Option<Arc<HashMap<String, PostHogFlagFilterPropertyValue>>>,
}
impl FeatureResolver {
pub fn new_disabled() -> Self {
Self {
inner: None,
internal_properties: None,
}
Self { inner: None }
}
pub fn spawn(
@@ -43,114 +36,14 @@ impl FeatureResolver {
shutdown_pageserver,
);
let inner = Arc::new(inner);
// The properties shared by all tenants on this pageserver.
let internal_properties = {
let mut properties = HashMap::new();
properties.insert(
"pageserver_id".to_string(),
PostHogFlagFilterPropertyValue::String(conf.id.to_string()),
);
if let Some(availability_zone) = &conf.availability_zone {
properties.insert(
"availability_zone".to_string(),
PostHogFlagFilterPropertyValue::String(availability_zone.clone()),
);
}
// Infer region based on the remote storage config.
if let Some(remote_storage) = &conf.remote_storage_config {
match &remote_storage.storage {
RemoteStorageKind::AwsS3(config) => {
properties.insert(
"region".to_string(),
PostHogFlagFilterPropertyValue::String(format!(
"aws-{}",
config.bucket_region
)),
);
}
RemoteStorageKind::AzureContainer(config) => {
properties.insert(
"region".to_string(),
PostHogFlagFilterPropertyValue::String(format!(
"azure-{}",
config.container_region
)),
);
}
RemoteStorageKind::LocalFs { .. } => {
properties.insert(
"region".to_string(),
PostHogFlagFilterPropertyValue::String("local".to_string()),
);
}
}
}
// TODO: add pageserver URL.
Arc::new(properties)
};
let fake_tenants = {
let mut tenants = Vec::new();
for i in 0..10 {
let distinct_id = format!(
"fake_tenant_{}_{}_{}",
conf.availability_zone.as_deref().unwrap_or_default(),
conf.id,
i
);
let properties = Self::collect_properties_inner(
distinct_id.clone(),
Some(&internal_properties),
);
tenants.push(CaptureEvent {
event: "initial_tenant_report".to_string(),
distinct_id,
properties: json!({ "$set": properties }), // use `$set` to set the person properties instead of the event properties
});
}
tenants
};
// TODO: make refresh period configurable
inner
.clone()
.spawn(handle, Duration::from_secs(60), fake_tenants);
Ok(FeatureResolver {
inner: Some(inner),
internal_properties: Some(internal_properties),
})
// TODO: make this configurable
inner.clone().spawn(handle, Duration::from_secs(60));
Ok(FeatureResolver { inner: Some(inner) })
} else {
Ok(FeatureResolver {
inner: None,
internal_properties: None,
})
Ok(FeatureResolver { inner: None })
}
}
fn collect_properties_inner(
tenant_id: String,
internal_properties: Option<&HashMap<String, PostHogFlagFilterPropertyValue>>,
) -> HashMap<String, PostHogFlagFilterPropertyValue> {
let mut properties = HashMap::new();
if let Some(internal_properties) = internal_properties {
for (key, value) in internal_properties.iter() {
properties.insert(key.clone(), value.clone());
}
}
properties.insert(
"tenant_id".to_string(),
PostHogFlagFilterPropertyValue::String(tenant_id),
);
properties
}
    /// Collect all properties available for the feature flag evaluation.
pub(crate) fn collect_properties(
&self,
tenant_id: TenantId,
) -> HashMap<String, PostHogFlagFilterPropertyValue> {
Self::collect_properties_inner(tenant_id.to_string(), self.internal_properties.as_deref())
}
/// Evaluate a multivariate feature flag. Currently, we do not support any properties.
///
/// Error handling: the caller should inspect the error and decide the behavior when a feature flag
@@ -162,24 +55,11 @@ impl FeatureResolver {
tenant_id: TenantId,
) -> Result<String, PostHogEvaluationError> {
if let Some(inner) = &self.inner {
let res = inner.feature_store().evaluate_multivariate(
inner.feature_store().evaluate_multivariate(
flag_key,
&tenant_id.to_string(),
&self.collect_properties(tenant_id),
);
match &res {
Ok(value) => {
FEATURE_FLAG_EVALUATION
.with_label_values(&[flag_key, "ok", value])
.inc();
}
Err(e) => {
FEATURE_FLAG_EVALUATION
.with_label_values(&[flag_key, "error", e.as_variant_str()])
.inc();
}
}
res
&HashMap::new(),
)
} else {
Err(PostHogEvaluationError::NotAvailable(
"PostHog integration is not enabled".to_string(),
@@ -200,34 +80,11 @@ impl FeatureResolver {
tenant_id: TenantId,
) -> Result<(), PostHogEvaluationError> {
if let Some(inner) = &self.inner {
let res = inner.feature_store().evaluate_boolean(
inner.feature_store().evaluate_boolean(
flag_key,
&tenant_id.to_string(),
&self.collect_properties(tenant_id),
);
match &res {
Ok(()) => {
FEATURE_FLAG_EVALUATION
.with_label_values(&[flag_key, "ok", "true"])
.inc();
}
Err(e) => {
FEATURE_FLAG_EVALUATION
.with_label_values(&[flag_key, "error", e.as_variant_str()])
.inc();
}
}
res
} else {
Err(PostHogEvaluationError::NotAvailable(
"PostHog integration is not enabled".to_string(),
))
}
}
pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result<bool, PostHogEvaluationError> {
if let Some(inner) = &self.inner {
inner.feature_store().is_feature_flag_boolean(flag_key)
&HashMap::new(),
)
} else {
Err(PostHogEvaluationError::NotAvailable(
"PostHog integration is not enabled".to_string(),

View File

@@ -43,7 +43,6 @@ use pageserver_api::models::{
use pageserver_api::shard::{ShardCount, TenantShardId};
use remote_storage::{DownloadError, GenericRemoteStorage, TimeTravelError};
use scopeguard::defer;
use serde_json::json;
use tenant_size_model::svg::SvgBranchKind;
use tenant_size_model::{SizeResult, StorageModel};
use tokio::time::Instant;
@@ -3664,47 +3663,6 @@ async fn read_tar_eof(mut reader: (impl tokio::io::AsyncRead + Unpin)) -> anyhow
Ok(())
}
async fn tenant_evaluate_feature_flag(
request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?;
check_permission(&request, Some(tenant_shard_id.tenant_id))?;
let flag: String = must_parse_query_param(&request, "flag")?;
let as_type: String = must_parse_query_param(&request, "as")?;
let state = get_state(&request);
async {
let tenant = state
.tenant_manager
.get_attached_tenant_shard(tenant_shard_id)?;
let properties = tenant.feature_resolver.collect_properties(tenant_shard_id.tenant_id);
if as_type == "boolean" {
let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id);
let result = result.map(|_| true).map_err(|e| e.to_string());
json_response(StatusCode::OK, json!({ "result": result, "properties": properties }))
} else if as_type == "multivariate" {
let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string());
json_response(StatusCode::OK, json!({ "result": result, "properties": properties }))
} else {
// Auto infer the type of the feature flag.
let is_boolean = tenant.feature_resolver.is_feature_flag_boolean(&flag).map_err(|e| ApiError::InternalServerError(anyhow::anyhow!("{e}")))?;
if is_boolean {
let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id);
let result = result.map(|_| true).map_err(|e| e.to_string());
json_response(StatusCode::OK, json!({ "result": result, "properties": properties }))
} else {
let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string());
json_response(StatusCode::OK, json!({ "result": result, "properties": properties }))
}
}
}
.instrument(info_span!("tenant_evaluate_feature_flag", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug()))
.await
}
/// Common functionality of all the HTTP API handlers.
///
/// - Adds a tracing span to each request (by `request_span`)
@@ -4081,8 +4039,5 @@ pub fn make_router(
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/activate_post_import",
|r| api_handler(r, activate_post_import_handler),
)
.get("/v1/tenant/:tenant_shard_id/feature_flag", |r| {
api_handler(r, tenant_evaluate_feature_flag)
})
.any(handler_404))
}

View File

@@ -15,7 +15,6 @@ use metrics::{
register_int_gauge, register_int_gauge_vec, register_uint_gauge, register_uint_gauge_vec,
};
use once_cell::sync::Lazy;
use pageserver_api::config::defaults::DEFAULT_MAX_GET_VECTORED_KEYS;
use pageserver_api::config::{
PageServicePipeliningConfig, PageServicePipeliningConfigPipelined,
PageServiceProtocolPipelinedBatchingStrategy, PageServiceProtocolPipelinedExecutionStrategy,
@@ -33,6 +32,7 @@ use crate::config::PageServerConf;
use crate::context::{PageContentKind, RequestContext};
use crate::pgdatadir_mapping::DatadirModificationStats;
use crate::task_mgr::TaskKind;
use crate::tenant::Timeline;
use crate::tenant::layer_map::LayerMap;
use crate::tenant::mgr::TenantSlot;
use crate::tenant::storage_layer::{InMemoryLayer, PersistentLayerDesc};
@@ -446,15 +446,6 @@ static PAGE_CACHE_ERRORS: Lazy<IntCounterVec> = Lazy::new(|| {
.expect("failed to define a metric")
});
pub(crate) static FEATURE_FLAG_EVALUATION: Lazy<CounterVec> = Lazy::new(|| {
register_counter_vec!(
"pageserver_feature_flag_evaluation",
"Number of times a feature flag is evaluated",
&["flag_key", "status", "value"],
)
.unwrap()
});
#[derive(IntoStaticStr)]
#[strum(serialize_all = "kebab_case")]
pub(crate) enum PageCacheErrorKind {
@@ -1321,44 +1312,11 @@ impl EvictionsWithLowResidenceDuration {
//
// Roughly logarithmic scale.
const STORAGE_IO_TIME_BUCKETS: &[f64] = &[
0.00005, // 50us
0.00006, // 60us
0.00007, // 70us
0.00008, // 80us
0.00009, // 90us
0.0001, // 100us
0.000110, // 110us
0.000120, // 120us
0.000130, // 130us
0.000140, // 140us
0.000150, // 150us
0.000160, // 160us
0.000170, // 170us
0.000180, // 180us
0.000190, // 190us
0.000200, // 200us
0.000210, // 210us
0.000220, // 220us
0.000230, // 230us
0.000240, // 240us
0.000250, // 250us
0.000300, // 300us
0.000350, // 350us
0.000400, // 400us
0.000450, // 450us
0.000500, // 500us
0.000600, // 600us
0.000700, // 700us
0.000800, // 800us
0.000900, // 900us
0.001000, // 1ms
0.002000, // 2ms
0.003000, // 3ms
0.004000, // 4ms
0.005000, // 5ms
0.01000, // 10ms
0.02000, // 20ms
0.05000, // 50ms
0.000030, // 30 usec
0.001000, // 1000 usec
0.030, // 30 ms
1.000, // 1000 ms
30.000, // 30000 ms
];
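These bucket arrays are only consumed by the histogram machinery, which assigns each observation to the first bucket whose upper bound is not below it. Below is a self-contained illustration of that assignment against the coarser bucket set in this hunk; the helper is not pageserver code (Prometheus histograms do this internally).

```rust
/// Coarse STORAGE_IO_TIME buckets from this hunk, in seconds.
const BUCKETS: &[f64] = &[0.000030, 0.001000, 0.030, 1.000, 30.000];

/// Returns the index of the first bucket whose upper bound is >= the observation,
/// or None if it falls into the implicit +Inf bucket.
fn bucket_for(seconds: f64) -> Option<usize> {
    BUCKETS.iter().position(|&upper| seconds <= upper)
}

fn main() {
    assert_eq!(bucket_for(0.000010), Some(0)); // 10us -> 30us bucket
    assert_eq!(bucket_for(0.0005), Some(1));   // 500us -> 1ms bucket
    assert_eq!(bucket_for(120.0), None);       // 2min -> +Inf bucket
    println!("ok");
}
```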
/// VirtualFile fs operation variants.
@@ -1948,7 +1906,7 @@ static SMGR_QUERY_TIME_GLOBAL: Lazy<HistogramVec> = Lazy::new(|| {
});
static PAGE_SERVICE_BATCH_SIZE_BUCKETS_GLOBAL: Lazy<Vec<f64>> = Lazy::new(|| {
(1..=u32::try_from(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap())
(1..=u32::try_from(Timeline::MAX_GET_VECTORED_KEYS).unwrap())
.map(|v| v.into())
.collect()
});
@@ -1966,7 +1924,7 @@ static PAGE_SERVICE_BATCH_SIZE_BUCKETS_PER_TIMELINE: Lazy<Vec<f64>> = Lazy::new(
let mut buckets = Vec::new();
for i in 0.. {
let bucket = 1 << i;
if bucket > u32::try_from(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap() {
if bucket > u32::try_from(Timeline::MAX_GET_VECTORED_KEYS).unwrap() {
break;
}
buckets.push(bucket.into());
@@ -2855,6 +2813,7 @@ pub(crate) struct WalIngestMetrics {
pub(crate) records_received: IntCounter,
pub(crate) records_observed: IntCounter,
pub(crate) records_committed: IntCounter,
pub(crate) records_filtered: IntCounter,
pub(crate) values_committed_metadata_images: IntCounter,
pub(crate) values_committed_metadata_deltas: IntCounter,
pub(crate) values_committed_data_images: IntCounter,
@@ -2910,6 +2869,11 @@ pub(crate) static WAL_INGEST: Lazy<WalIngestMetrics> = Lazy::new(|| {
"Number of WAL records which resulted in writes to pageserver storage"
)
.expect("failed to define a metric"),
records_filtered: register_int_counter!(
"pageserver_wal_ingest_records_filtered",
"Number of WAL records filtered out due to sharding"
)
.expect("failed to define a metric"),
values_committed_metadata_images: values_committed.with_label_values(&["metadata", "image"]),
values_committed_metadata_deltas: values_committed.with_label_values(&["metadata", "delta"]),
values_committed_data_images: values_committed.with_label_values(&["data", "image"]),

File diff suppressed because it is too large

View File

@@ -431,10 +431,10 @@ impl Timeline {
GetVectoredError::InvalidLsn(e) => {
Err(anyhow::anyhow!("invalid LSN: {e:?}").into())
}
// NB: this should never happen in practice because we limit batch size to be smaller than max_get_vectored_keys
// NB: this should never happen in practice because we limit MAX_GET_VECTORED_KEYS
// TODO: we can prevent this error class by moving this check into the type system
GetVectoredError::Oversized(err, max) => {
Err(anyhow::anyhow!("batching oversized: {err} > {max}").into())
GetVectoredError::Oversized(err) => {
Err(anyhow::anyhow!("batching oversized: {err:?}").into())
}
};
@@ -471,19 +471,8 @@ impl Timeline {
let rels = self.list_rels(spcnode, dbnode, version, ctx).await?;
if rels.is_empty() {
return Ok(0);
}
// Pre-deserialize the rel directory to avoid duplicated work in `get_relsize_cached`.
let reldir_key = rel_dir_to_key(spcnode, dbnode);
let buf = version.get(self, reldir_key, ctx).await?;
let reldir = RelDirectory::des(&buf)?;
for rel in rels {
let n_blocks = self
.get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx)
.await?;
let n_blocks = self.get_rel_size(rel, version, ctx).await?;
total_blocks += n_blocks as usize;
}
Ok(total_blocks)
@@ -498,19 +487,6 @@ impl Timeline {
tag: RelTag,
version: Version<'_>,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
self.get_rel_size_in_reldir(tag, version, None, ctx).await
}
/// Get size of a relation file. The relation must exist, otherwise an error is returned.
///
/// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`.
pub(crate) async fn get_rel_size_in_reldir(
&self,
tag: RelTag,
version: Version<'_>,
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
if tag.relnode == 0 {
return Err(PageReconstructError::Other(
@@ -523,9 +499,7 @@ impl Timeline {
}
if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM)
&& !self
.get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx)
.await?
&& !self.get_rel_exists(tag, version, ctx).await?
{
// FIXME: Postgres sometimes calls smgrcreate() to create
// FSM, and smgrnblocks() on it immediately afterwards,
@@ -547,28 +521,11 @@ impl Timeline {
///
/// Only shard 0 has a full view of the relations. Other shards only know about relations that
/// the shard stores pages for.
///
pub(crate) async fn get_rel_exists(
&self,
tag: RelTag,
version: Version<'_>,
ctx: &RequestContext,
) -> Result<bool, PageReconstructError> {
self.get_rel_exists_in_reldir(tag, version, None, ctx).await
}
/// Does the relation exist? With a cached deserialized `RelDirectory`.
///
/// There are some cases where the caller loops across all relations. In that specific case,
/// the caller should obtain the deserialized `RelDirectory` first and then call this function
/// to avoid duplicated deserialization work. This is a hack and should be removed by introducing
/// a new API (e.g., `get_rel_exists_batched`).
pub(crate) async fn get_rel_exists_in_reldir(
&self,
tag: RelTag,
version: Version<'_>,
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
ctx: &RequestContext,
) -> Result<bool, PageReconstructError> {
if tag.relnode == 0 {
return Err(PageReconstructError::Other(
@@ -611,17 +568,6 @@ impl Timeline {
// fetch directory listing (old)
let key = rel_dir_to_key(tag.spcnode, tag.dbnode);
if let Some((cached_key, dir)) = deserialized_reldir_v1 {
if cached_key == key {
return Ok(dir.rels.contains(&(tag.relnode, tag.forknum)));
} else if cfg!(test) || cfg!(feature = "testing") {
panic!("cached reldir key mismatch: {cached_key} != {key}");
} else {
warn!("cached reldir key mismatch: {cached_key} != {key}");
}
// Fallback to reading the directory from the datadir.
}
let buf = version.get(self, key, ctx).await?;
let dir = RelDirectory::des(&buf)?;
@@ -719,7 +665,7 @@ impl Timeline {
let batches = keyspace.partition(
self.get_shard_identity(),
self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64,
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
);
let io_concurrency = IoConcurrency::spawn_from_conf(
@@ -959,7 +905,7 @@ impl Timeline {
let batches = keyspace.partition(
self.get_shard_identity(),
self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64,
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
);
let io_concurrency = IoConcurrency::spawn_from_conf(

View File

@@ -383,7 +383,7 @@ pub struct TenantShard {
l0_flush_global_state: L0FlushGlobalState,
pub(crate) feature_resolver: FeatureResolver,
feature_resolver: FeatureResolver,
}
impl std::fmt::Debug for TenantShard {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -5832,7 +5832,6 @@ pub(crate) mod harness {
pub conf: &'static PageServerConf,
pub tenant_conf: pageserver_api::models::TenantConfig,
pub tenant_shard_id: TenantShardId,
pub shard_identity: ShardIdentity,
pub generation: Generation,
pub shard: ShardIndex,
pub remote_storage: GenericRemoteStorage,
@@ -5900,7 +5899,6 @@ pub(crate) mod harness {
conf,
tenant_conf,
tenant_shard_id,
shard_identity,
generation,
shard,
remote_storage,
@@ -5962,7 +5960,8 @@ pub(crate) mod harness {
&ShardParameters::default(),
))
.unwrap(),
self.shard_identity,
// This is a legacy/test code path: sharding isn't supported here.
ShardIdentity::unsharded(),
Some(walredo_mgr),
self.tenant_shard_id,
self.remote_storage.clone(),
@@ -6084,7 +6083,6 @@ mod tests {
use timeline::compaction::{KeyHistoryRetention, KeyLogAtLsn};
use timeline::{CompactOptions, DeltaLayerTestDesc, VersionedKeySpaceQuery};
use utils::id::TenantId;
use utils::shard::{ShardCount, ShardNumber};
use super::*;
use crate::DEFAULT_PG_VERSION;
@@ -7197,7 +7195,7 @@ mod tests {
let end = desc
.key_range
.start
.add(tenant.conf.max_get_vectored_keys.get() as u32);
.add(Timeline::MAX_GET_VECTORED_KEYS.try_into().unwrap());
reads.push(KeySpace {
ranges: vec![start..end],
});
@@ -9420,77 +9418,6 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn test_failed_flush_should_not_update_disk_consistent_lsn() -> anyhow::Result<()> {
//
// Setup
//
let harness = TenantHarness::create_custom(
"test_failed_flush_should_not_upload_disk_consistent_lsn",
pageserver_api::models::TenantConfig::default(),
TenantId::generate(),
ShardIdentity::new(ShardNumber(0), ShardCount(4), ShardStripeSize(128)).unwrap(),
Generation::new(1),
)
.await?;
let (tenant, ctx) = harness.load().await;
let timeline = tenant
.create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx)
.await?;
assert_eq!(timeline.get_shard_identity().count, ShardCount(4));
let mut writer = timeline.writer().await;
writer
.put(
*TEST_KEY,
Lsn(0x20),
&Value::Image(test_img("foo at 0x20")),
&ctx,
)
.await?;
writer.finish_write(Lsn(0x20));
drop(writer);
timeline.freeze_and_flush().await.unwrap();
timeline.remote_client.wait_completion().await.unwrap();
let disk_consistent_lsn = timeline.get_disk_consistent_lsn();
let remote_consistent_lsn = timeline.get_remote_consistent_lsn_projected();
assert_eq!(Some(disk_consistent_lsn), remote_consistent_lsn);
//
// Test
//
let mut writer = timeline.writer().await;
writer
.put(
*TEST_KEY,
Lsn(0x30),
&Value::Image(test_img("foo at 0x30")),
&ctx,
)
.await?;
writer.finish_write(Lsn(0x30));
drop(writer);
fail::cfg(
"flush-layer-before-update-remote-consistent-lsn",
"return()",
)
.unwrap();
let flush_res = timeline.freeze_and_flush().await;
// if flush failed, the disk/remote consistent LSN should not be updated
assert!(flush_res.is_err());
assert_eq!(disk_consistent_lsn, timeline.get_disk_consistent_lsn());
assert_eq!(
remote_consistent_lsn,
timeline.get_remote_consistent_lsn_projected()
);
Ok(())
}
#[cfg(feature = "testing")]
#[tokio::test]
async fn test_simple_bottom_most_compaction_deltas_1() -> anyhow::Result<()> {
@@ -11260,11 +11187,11 @@ mod tests {
let mut keyspaces_at_lsn: HashMap<Lsn, KeySpaceRandomAccum> = HashMap::default();
let mut used_keys: HashSet<Key> = HashSet::default();
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
while used_keys.len() < Timeline::MAX_GET_VECTORED_KEYS as usize {
let selected_lsn = interesting_lsns.choose(&mut random).expect("not empty");
let mut selected_key = start_key.add(random.gen_range(0..KEY_DIMENSION_SIZE));
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
while used_keys.len() < Timeline::MAX_GET_VECTORED_KEYS as usize {
if used_keys.contains(&selected_key)
|| selected_key >= start_key.add(KEY_DIMENSION_SIZE)
{

View File

@@ -817,8 +817,8 @@ pub(crate) enum GetVectoredError {
#[error("timeline shutting down")]
Cancelled,
#[error("requested too many keys: {0} > {1}")]
Oversized(u64, u64),
#[error("requested too many keys: {0} > {}", Timeline::MAX_GET_VECTORED_KEYS)]
Oversized(u64),
#[error("requested at invalid LSN: {0}")]
InvalidLsn(Lsn),
@@ -950,18 +950,6 @@ pub(crate) enum WaitLsnError {
Timeout(String),
}
impl From<WaitLsnError> for tonic::Status {
fn from(err: WaitLsnError) -> Self {
use tonic::Code;
let code = match &err {
WaitLsnError::Timeout(_) => Code::Internal,
WaitLsnError::BadState(_) => Code::Internal,
WaitLsnError::Shutdown => Code::Unavailable,
};
tonic::Status::new(code, err.to_string())
}
}
// The impls below achieve cancellation mapping for errors.
// Perhaps there's a way of achieving this with less cruft.
@@ -1019,7 +1007,7 @@ impl From<GetVectoredError> for PageReconstructError {
match e {
GetVectoredError::Cancelled => PageReconstructError::Cancelled,
GetVectoredError::InvalidLsn(_) => PageReconstructError::Other(anyhow!("Invalid LSN")),
err @ GetVectoredError::Oversized(_, _) => PageReconstructError::Other(err.into()),
err @ GetVectoredError::Oversized(_) => PageReconstructError::Other(err.into()),
GetVectoredError::MissingKey(err) => PageReconstructError::MissingKey(err),
GetVectoredError::GetReadyAncestorError(err) => PageReconstructError::from(err),
GetVectoredError::Other(err) => PageReconstructError::Other(err),
@@ -1199,6 +1187,7 @@ impl Timeline {
}
}
pub(crate) const MAX_GET_VECTORED_KEYS: u64 = 32;
pub(crate) const LAYERS_VISITED_WARN_THRESHOLD: u32 = 100;
/// Look up multiple page versions at a given LSN
@@ -1213,12 +1202,9 @@ impl Timeline {
) -> Result<BTreeMap<Key, Result<Bytes, PageReconstructError>>, GetVectoredError> {
let total_keyspace = query.total_keyspace();
let key_count = total_keyspace.total_raw_size();
if key_count > self.conf.max_get_vectored_keys.get() {
return Err(GetVectoredError::Oversized(
key_count as u64,
self.conf.max_get_vectored_keys.get() as u64,
));
let key_count = total_keyspace.total_raw_size().try_into().unwrap();
if key_count > Timeline::MAX_GET_VECTORED_KEYS {
return Err(GetVectoredError::Oversized(key_count));
}
for range in &total_keyspace.ranges {
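Both sides of the hunk above enforce the same guard; only the source of the limit differs (the `max_get_vectored_keys` config value versus the `MAX_GET_VECTORED_KEYS` constant). A minimal stand-alone sketch of that guard, with stand-in types rather than the real error enum:

```rust
/// Stand-in for the oversized-batch error carried by GetVectoredError.
#[derive(Debug, PartialEq)]
struct Oversized {
    requested: u64,
    max: u64,
}

/// Reject a vectored get whose total key count exceeds the configured limit.
fn check_batch_size(key_count: u64, max_get_vectored_keys: u64) -> Result<(), Oversized> {
    if key_count > max_get_vectored_keys {
        return Err(Oversized {
            requested: key_count,
            max: max_get_vectored_keys,
        });
    }
    Ok(())
}

fn main() {
    assert!(check_batch_size(32, 32).is_ok());
    assert_eq!(
        check_batch_size(33, 32),
        Err(Oversized { requested: 33, max: 32 })
    );
    println!("ok");
}
```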
@@ -2506,13 +2492,6 @@ impl Timeline {
// Preparing basebackup doesn't make sense for shards other than shard zero.
return;
}
if !self.is_active() {
// May happen during initial timeline creation.
// Such timeline is not in the global timeline map yet,
// so basebackup cache will not be able to find it.
// TODO(diko): We can prepare such timelines in finish_creation().
return;
}
let res = self
.basebackup_prepare_sender
@@ -2852,6 +2831,21 @@ impl Timeline {
)
}
/// Resolve the effective WAL receiver protocol to use for this tenant.
///
/// Priority order is:
/// 1. Tenant config override
/// 2. Default value for tenant config override
/// 3. Pageserver config override
/// 4. Pageserver config default
pub fn resolve_wal_receiver_protocol(&self) -> PostgresClientProtocol {
let tenant_conf = self.tenant_conf.load().tenant_conf.clone();
tenant_conf
.wal_receiver_protocol_override
.or(self.conf.default_tenant_conf.wal_receiver_protocol_override)
.unwrap_or(self.conf.wal_receiver_protocol)
}
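The documented priority order is exactly an `Option` chain; a self-contained illustration of the same precedence, using a placeholder protocol enum rather than the real `PostgresClientProtocol`:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum Protocol {
    Vanilla,
    Interpreted,
}

/// Mirrors the `.or(..).unwrap_or(..)` chain above: the tenant override wins,
/// then the default tenant-conf override, then the pageserver-wide value.
fn resolve(
    tenant_override: Option<Protocol>,
    default_tenant_override: Option<Protocol>,
    pageserver_default: Protocol,
) -> Protocol {
    tenant_override.or(default_tenant_override).unwrap_or(pageserver_default)
}

fn main() {
    assert_eq!(resolve(None, None, Protocol::Interpreted), Protocol::Interpreted);
    assert_eq!(resolve(None, Some(Protocol::Vanilla), Protocol::Interpreted), Protocol::Vanilla);
    assert_eq!(
        resolve(Some(Protocol::Interpreted), Some(Protocol::Vanilla), Protocol::Vanilla),
        Protocol::Interpreted
    );
    println!("ok");
}
```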
pub(super) fn tenant_conf_updated(&self, new_conf: &AttachedTenantConf) {
// NB: Most tenant conf options are read by background loops, so,
// changes will automatically be picked up.
@@ -3207,16 +3201,10 @@ impl Timeline {
guard.is_none(),
"multiple launches / re-launches of WAL receiver are not supported"
);
let protocol = PostgresClientProtocol::Interpreted {
format: utils::postgres_client::InterpretedFormat::Protobuf,
compression: Some(utils::postgres_client::Compression::Zstd { level: 1 }),
};
*guard = Some(WalReceiver::start(
Arc::clone(self),
WalReceiverConf {
protocol,
protocol: self.resolve_wal_receiver_protocol(),
wal_connect_timeout,
lagging_wal_timeout,
max_lsn_wal_lag,
@@ -4779,10 +4767,7 @@ impl Timeline {
|| !flushed_to_lsn.is_valid()
);
if flushed_to_lsn < frozen_to_lsn
&& self.shard_identity.count.count() > 1
&& result.is_ok()
{
if flushed_to_lsn < frozen_to_lsn && self.shard_identity.count.count() > 1 {
// If our layer flushes didn't carry disk_consistent_lsn up to the `to_lsn` advertised
// to us via layer_flush_start_rx, then advance it here.
//
@@ -4961,10 +4946,6 @@ impl Timeline {
return Err(FlushLayerError::Cancelled);
}
fail_point!("flush-layer-before-update-remote-consistent-lsn", |_| {
Err(FlushLayerError::Other(anyhow!("failpoint").into()))
});
let disk_consistent_lsn = Lsn(lsn_range.end.0 - 1);
// The new on-disk layers are now in the layer map. We can remove the
@@ -5270,7 +5251,7 @@ impl Timeline {
key = key.next();
// Maybe flush `key_rest_accum`
if key_request_accum.raw_size() >= self.conf.max_get_vectored_keys.get() as u64
if key_request_accum.raw_size() >= Timeline::MAX_GET_VECTORED_KEYS
|| (last_key_in_range && key_request_accum.raw_size() > 0)
{
let query =

View File

@@ -106,8 +106,6 @@ pub async fn doit(
);
}
tracing::info!("Import plan executed. Flushing remote changes and notifying storcon");
timeline
.remote_client
.schedule_index_upload_for_file_changes()?;
@@ -201,8 +199,8 @@ async fn prepare_import(
.await;
match res {
Ok(_) => break,
Err(_err) => {
info!("indefinitely waiting for pgdata to finish");
Err(err) => {
info!(?err, "indefinitely waiting for pgdata to finish");
if tokio::time::timeout(std::time::Duration::from_secs(10), cancel.cancelled())
.await
.is_ok()

View File

@@ -11,7 +11,19 @@
//! - => S3 as the source for the PGDATA instead of local filesystem
//!
//! TODOs before productionization:
//! - ChunkProcessingJob size / ImportJob::total_size does not account for sharding.
//! => produced image layers likely too small.
//! - ChunkProcessingJob should cut up an ImportJob to hit exactly target image layer size.
//! - asserts / unwraps need to be replaced with errors
//! - don't trust remote objects will be small (=prevent OOMs in those cases)
//! - limit all in-memory buffers in size, or download to disk and read from there
//! - limit task concurrency
//! - generally play nice with other tenants in the system
//! - importbucket is different bucket than main pageserver storage, so, should be fine wrt S3 rate limits
//! - but concerns like network bandwidth, local disk write bandwidth, local disk capacity, etc
//! - integrate with layer eviction system
//! - audit for Tenant::cancel and Timeline::cancel responsiveness
//! - audit for Tenant/Timeline gate holding (we spawn tokio tasks during this flow!)
//!
//! An incomplete set of TODOs from the Hackathon:
//! - version-specific CheckPointData (=> pgv abstraction, already exists for regular walingest)
@@ -32,7 +44,7 @@ use pageserver_api::key::{
rel_dir_to_key, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key,
slru_segment_size_to_key,
};
use pageserver_api::keyspace::{ShardedRange, singleton_range};
use pageserver_api::keyspace::{contiguous_range_len, is_contiguous_range, singleton_range};
use pageserver_api::models::{ShardImportProgress, ShardImportProgressV1, ShardImportStatus};
use pageserver_api::reltag::{RelTag, SlruKind};
use pageserver_api::shard::ShardIdentity;
@@ -100,7 +112,6 @@ async fn run_v1(
.unwrap(),
import_job_concurrency: base.import_job_concurrency,
import_job_checkpoint_threshold: base.import_job_checkpoint_threshold,
import_job_max_byte_range_size: base.import_job_max_byte_range_size,
}
}
None => timeline.conf.timeline_import_config.clone(),
@@ -131,15 +142,7 @@ async fn run_v1(
pausable_failpoint!("import-timeline-pre-execute-pausable");
let jobs_count = import_progress.as_ref().map(|p| p.jobs);
let start_from_job_idx = import_progress.map(|progress| progress.completed);
tracing::info!(
start_from_job_idx=?start_from_job_idx,
jobs=?jobs_count,
"Executing import plan"
);
plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx)
.await
}
@@ -164,7 +167,6 @@ impl Planner {
/// This function is and must remain pure: given the same input, it will generate the same import plan.
async fn plan(mut self, import_config: &TimelineImportConfig) -> anyhow::Result<Plan> {
let pgdata_lsn = Lsn(self.control_file.control_file_data().checkPoint).align();
anyhow::ensure!(pgdata_lsn.is_valid());
let datadir = PgDataDir::new(&self.storage).await?;
@@ -247,22 +249,14 @@ impl Planner {
});
// Assigns parts of key space to later parallel jobs
// Note: The image layers produced here may have gaps, meaning,
// there is not an image for each key in the layer's key range.
// The read path stops traversal at the first image layer, regardless
// of whether a base image has been found for a key or not.
// (Concept of sparse image layers doesn't exist.)
// This behavior is exactly right for the base image layers we're producing here.
// But, since no other place in the code currently produces image layers with gaps,
// it seems noteworthy.
let mut last_end_key = Key::MIN;
let mut current_chunk = Vec::new();
let mut current_chunk_size: usize = 0;
let mut jobs = Vec::new();
for task in std::mem::take(&mut self.tasks).into_iter() {
let task_size = task.total_size(&self.shard);
let projected_chunk_size = current_chunk_size.saturating_add(task_size);
if projected_chunk_size > import_config.import_job_soft_size_limit.into() {
if current_chunk_size + task.total_size()
> import_config.import_job_soft_size_limit.into()
{
let key_range = last_end_key..task.key_range().start;
jobs.push(ChunkProcessingJob::new(
key_range.clone(),
@@ -272,7 +266,7 @@ impl Planner {
last_end_key = key_range.end;
current_chunk_size = 0;
}
current_chunk_size = current_chunk_size.saturating_add(task_size);
current_chunk_size += task.total_size();
current_chunk.push(task);
}
jobs.push(ChunkProcessingJob::new(
@@ -442,7 +436,6 @@ impl Plan {
let mut last_completed_job_idx = start_after_job_idx.unwrap_or(0);
let checkpoint_every: usize = import_config.import_job_checkpoint_threshold.into();
let max_byte_range_size: usize = import_config.import_job_max_byte_range_size.into();
// Run import jobs concurrently up to the limit specified by the pageserver configuration.
// Note that we process completed futures in the order of insertion. This will be the
@@ -458,7 +451,7 @@ impl Plan {
work.push_back(tokio::task::spawn(async move {
let _permit = permit;
let res = job.run(job_timeline, max_byte_range_size, &ctx).await;
let res = job.run(job_timeline, &ctx).await;
(job_idx, res)
}));
},
@@ -473,8 +466,6 @@ impl Plan {
last_completed_job_idx = job_idx;
if last_completed_job_idx % checkpoint_every == 0 {
tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status");
let progress = ShardImportProgressV1 {
jobs: jobs_in_plan,
completed: last_completed_job_idx,
@@ -613,18 +604,18 @@ impl PgDataDirDb {
};
let path = datadir_path.join(rel_tag.to_segfile_name(segno));
anyhow::ensure!(filesize % BLCKSZ as usize == 0);
assert!(filesize % BLCKSZ as usize == 0); // TODO: this should result in an error
let nblocks = filesize / BLCKSZ as usize;
Ok(PgDataDirDbFile {
PgDataDirDbFile {
path,
filesize,
rel_tag,
segno,
nblocks: Some(nblocks), // first non-cumulative sizes
})
}
})
.collect::<anyhow::Result<_, _>>()?;
.collect();
// Set cumulative sizes. Do all of that math here, so that later we can more easily
// parallelize over segments and know with which segments we need to write relsize
@@ -659,29 +650,18 @@ impl PgDataDirDb {
trait ImportTask {
fn key_range(&self) -> Range<Key>;
fn total_size(&self, shard_identity: &ShardIdentity) -> usize {
let range = ShardedRange::new(self.key_range(), shard_identity);
let page_count = range.page_count();
if page_count == u32::MAX {
tracing::warn!(
"Import task has non contiguous key range: {}..{}",
self.key_range().start,
self.key_range().end
);
// Tasks should operate on contiguous ranges. It is unexpected for
// ranges to violate this assumption. Calling code handles this by mapping
// any task on a non-contiguous range to its own image layer.
usize::MAX
fn total_size(&self) -> usize {
// TODO: revisit this
if is_contiguous_range(&self.key_range()) {
contiguous_range_len(&self.key_range()) as usize * 8192
} else {
page_count as usize * 8192
u32::MAX as usize
}
}
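Both versions of `total_size` above reduce to "pages in the range times the 8 KiB block size, with a sentinel for ranges whose page count cannot be determined". A stand-alone sketch of that estimate; the `Option` below stands in for the `u32::MAX` sentinel returned by `ShardedRange::page_count` in the diff:

```rust
const PAGE_SIZE: usize = 8192;

/// Stand-in for a key range's page count; `None` models the u32::MAX
/// "non-contiguous / unknown" sentinel used above.
fn total_size(page_count: Option<u32>) -> usize {
    match page_count {
        // Unknown size: make the task "huge" so the chunking code places it in its own image layer.
        None => usize::MAX,
        Some(pages) => pages as usize * PAGE_SIZE,
    }
}

fn main() {
    assert_eq!(total_size(Some(128)), 128 * 8192); // one 1 MiB chunk's worth of pages
    assert_eq!(total_size(None), usize::MAX);
    println!("ok");
}
```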
async fn doit(
self,
layer_writer: &mut ImageLayerWriter,
max_byte_range_size: usize,
ctx: &RequestContext,
) -> anyhow::Result<usize>;
}
@@ -718,7 +698,6 @@ impl ImportTask for ImportSingleKeyTask {
async fn doit(
self,
layer_writer: &mut ImageLayerWriter,
_max_byte_range_size: usize,
ctx: &RequestContext,
) -> anyhow::Result<usize> {
layer_writer.put_image(self.key, self.buf, ctx).await?;
@@ -772,7 +751,6 @@ impl ImportTask for ImportRelBlocksTask {
async fn doit(
self,
layer_writer: &mut ImageLayerWriter,
max_byte_range_size: usize,
ctx: &RequestContext,
) -> anyhow::Result<usize> {
debug!("Importing relation file");
@@ -799,7 +777,7 @@ impl ImportTask for ImportRelBlocksTask {
assert_eq!(key.len(), 1);
assert!(!acc.is_empty());
assert!(acc_end > acc_start);
if acc_end == start && end - acc_start <= max_byte_range_size {
if acc_end == start /* TODO additional max range check here, to limit memory consumption per task to X */ {
acc.push(key.pop().unwrap());
Ok((acc, acc_start, end))
} else {
@@ -814,8 +792,8 @@ impl ImportTask for ImportRelBlocksTask {
.get_range(&self.path, range_start.into_u64(), range_end.into_u64())
.await?;
let mut buf = Bytes::from(range_buf);
// TODO: batched writes
for key in keys {
// The writer buffers writes internally
let image = buf.split_to(8192);
layer_writer.put_image(key, image, ctx).await?;
nimages += 1;
@@ -863,15 +841,11 @@ impl ImportTask for ImportSlruBlocksTask {
async fn doit(
self,
layer_writer: &mut ImageLayerWriter,
_max_byte_range_size: usize,
ctx: &RequestContext,
) -> anyhow::Result<usize> {
debug!("Importing SLRU segment file {}", self.path);
let buf = self.storage.get(&self.path).await?;
// TODO(vlad): Does timestamp to LSN work for imported timelines?
// Probably not since we don't append the `xact_time` to it as in
// [`WalIngest::ingest_xact_record`].
let (kind, segno, start_blk) = self.key_range.start.to_slru_block()?;
let (_kind, _segno, end_blk) = self.key_range.end.to_slru_block()?;
let mut blknum = start_blk;
@@ -910,13 +884,12 @@ impl ImportTask for AnyImportTask {
async fn doit(
self,
layer_writer: &mut ImageLayerWriter,
max_byte_range_size: usize,
ctx: &RequestContext,
) -> anyhow::Result<usize> {
match self {
Self::SingleKey(t) => t.doit(layer_writer, max_byte_range_size, ctx).await,
Self::RelBlocks(t) => t.doit(layer_writer, max_byte_range_size, ctx).await,
Self::SlruBlocks(t) => t.doit(layer_writer, max_byte_range_size, ctx).await,
Self::SingleKey(t) => t.doit(layer_writer, ctx).await,
Self::RelBlocks(t) => t.doit(layer_writer, ctx).await,
Self::SlruBlocks(t) => t.doit(layer_writer, ctx).await,
}
}
}
@@ -957,12 +930,7 @@ impl ChunkProcessingJob {
}
}
async fn run(
self,
timeline: Arc<Timeline>,
max_byte_range_size: usize,
ctx: &RequestContext,
) -> anyhow::Result<()> {
async fn run(self, timeline: Arc<Timeline>, ctx: &RequestContext) -> anyhow::Result<()> {
let mut writer = ImageLayerWriter::new(
timeline.conf,
timeline.timeline_id,
@@ -977,7 +945,7 @@ impl ChunkProcessingJob {
let mut nimages = 0;
for task in self.tasks {
nimages += task.doit(&mut writer, max_byte_range_size, ctx).await?;
nimages += task.doit(&mut writer, ctx).await?;
}
let resident_layer = if nimages > 0 {

View File

@@ -6,7 +6,7 @@ use bytes::Bytes;
use postgres_ffi::ControlFileData;
use remote_storage::{
Download, DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, Listing,
ListingObject, RemotePath, RemoteStorageConfig,
ListingObject, RemotePath,
};
use serde::de::DeserializeOwned;
use tokio_util::sync::CancellationToken;
@@ -22,9 +22,11 @@ pub async fn new(
location: &index_part_format::Location,
cancel: CancellationToken,
) -> Result<RemoteStorageWrapper, anyhow::Error> {
// Downloads should be reasonably sized. We do ranged reads for relblock raw data
// and full reads for SLRU segments which are bounded by Postgres.
let timeout = RemoteStorageConfig::DEFAULT_TIMEOUT;
// FIXME: we probably want some timeout, and we might be able to assume the max file
// size on S3 is 1GiB (postgres segment size). But the problem is that the individual
// downloaders don't know enough about concurrent downloads to make a guess on the
// expected bandwidth and resulting best timeout.
let timeout = std::time::Duration::from_secs(24 * 60 * 60);
let location_storage = match location {
#[cfg(feature = "testing")]
index_part_format::Location::LocalFs { path } => {
@@ -48,12 +50,9 @@ pub async fn new(
.import_pgdata_aws_endpoint_url
.clone()
.map(|url| url.to_string()), // by specifying None here, remote_storage/aws-sdk-rust will infer from env
// This matches the default import job concurrency. This is managed
// separately from the usual S3 client, but the concern here is bandwidth
// usage.
concurrency_limit: 128.try_into().unwrap(),
max_keys_per_list_response: Some(1000),
upload_storage_class: None, // irrelevant
concurrency_limit: 100.try_into().unwrap(), // TODO: think about this
max_keys_per_list_response: Some(1000), // TODO: think about this
upload_storage_class: None, // irrelevant
},
timeout,
)

View File

@@ -113,7 +113,7 @@ impl WalReceiver {
}
connection_manager_state.shutdown().await;
*loop_status.write().unwrap() = None;
info!("task exits");
debug!("task exits");
}
.instrument(info_span!(parent: None, "wal_connection_manager", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), timeline_id = %timeline_id))
});

View File

@@ -32,7 +32,9 @@ use utils::backoff::{
};
use utils::id::{NodeId, TenantTimelineId};
use utils::lsn::Lsn;
use utils::postgres_client::{ConnectionConfigArgs, wal_stream_connection_config};
use utils::postgres_client::{
ConnectionConfigArgs, PostgresClientProtocol, wal_stream_connection_config,
};
use super::walreceiver_connection::{WalConnectionStatus, WalReceiverError};
use super::{TaskEvent, TaskHandle, TaskStateUpdate, WalReceiverConf};
@@ -989,12 +991,19 @@ impl ConnectionManagerState {
return None; // no connection string, ignore sk
}
let shard_identity = self.timeline.get_shard_identity();
let (shard_number, shard_count, shard_stripe_size) = (
Some(shard_identity.number.0),
Some(shard_identity.count.0),
Some(shard_identity.stripe_size.0),
);
let (shard_number, shard_count, shard_stripe_size) = match self.conf.protocol {
PostgresClientProtocol::Vanilla => {
(None, None, None)
},
PostgresClientProtocol::Interpreted { .. } => {
let shard_identity = self.timeline.get_shard_identity();
(
Some(shard_identity.number.0),
Some(shard_identity.count.0),
Some(shard_identity.stripe_size.0),
)
}
};
let connection_conf_args = ConnectionConfigArgs {
protocol: self.conf.protocol,
@@ -1111,8 +1120,8 @@ impl ReconnectReason {
#[cfg(test)]
mod tests {
use pageserver_api::config::defaults::DEFAULT_WAL_RECEIVER_PROTOCOL;
use url::Host;
use utils::postgres_client::PostgresClientProtocol;
use super::*;
use crate::tenant::harness::{TIMELINE_ID, TenantHarness};
@@ -1543,11 +1552,6 @@ mod tests {
.await
.expect("Failed to create an empty timeline for dummy wal connection manager");
let protocol = PostgresClientProtocol::Interpreted {
format: utils::postgres_client::InterpretedFormat::Protobuf,
compression: Some(utils::postgres_client::Compression::Zstd { level: 1 }),
};
ConnectionManagerState {
id: TenantTimelineId {
tenant_id: harness.tenant_shard_id.tenant_id,
@@ -1556,7 +1560,7 @@ mod tests {
timeline,
cancel: CancellationToken::new(),
conf: WalReceiverConf {
protocol,
protocol: DEFAULT_WAL_RECEIVER_PROTOCOL,
wal_connect_timeout: Duration::from_secs(1),
lagging_wal_timeout: Duration::from_secs(1),
max_lsn_wal_lag: NonZeroU64::new(1024 * 1024).unwrap(),

View File

@@ -15,7 +15,7 @@ use postgres_backend::is_expected_io_error;
use postgres_connection::PgConnectionConfig;
use postgres_ffi::WAL_SEGMENT_SIZE;
use postgres_ffi::v14::xlog_utils::normalize_lsn;
use postgres_ffi::waldecoder::WalDecodeError;
use postgres_ffi::waldecoder::{WalDecodeError, WalStreamDecoder};
use postgres_protocol::message::backend::ReplicationMessage;
use postgres_types::PgLsn;
use tokio::sync::watch;
@@ -31,7 +31,7 @@ use utils::lsn::Lsn;
use utils::pageserver_feedback::PageserverFeedback;
use utils::postgres_client::PostgresClientProtocol;
use utils::sync::gate::GateError;
use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecords};
use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords};
use wal_decoder::wire_format::FromWireFormat;
use super::TaskStateUpdate;
@@ -275,6 +275,8 @@ pub(super) async fn handle_walreceiver_connection(
let copy_stream = replication_client.copy_both_simple(&query).await?;
let mut physical_stream = pin!(ReplicationStream::new(copy_stream));
let mut waldecoder = WalStreamDecoder::new(startpoint, timeline.pg_version);
let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx)
.await
.map_err(|e| match e.kind {
@@ -282,22 +284,19 @@ pub(super) async fn handle_walreceiver_connection(
_ => WalReceiverError::Other(e.into()),
})?;
let (format, compression) = match protocol {
let shard = vec![*timeline.get_shard_identity()];
let interpreted_proto_config = match protocol {
PostgresClientProtocol::Vanilla => None,
PostgresClientProtocol::Interpreted {
format,
compression,
} => (format, compression),
PostgresClientProtocol::Vanilla => {
return Err(WalReceiverError::Other(anyhow!(
"Vanilla WAL receiver protocol is no longer supported for ingest"
)));
}
} => Some((format, compression)),
};
let mut expected_wal_start = startpoint;
while let Some(replication_message) = {
select! {
biased;
_ = cancellation.cancelled() => {
debug!("walreceiver interrupted");
None
@@ -313,6 +312,16 @@ pub(super) async fn handle_walreceiver_connection(
// Update the connection status before processing the message. If the message processing
// fails (e.g. in walingest), we still want to know latests LSNs from the safekeeper.
match &replication_message {
ReplicationMessage::XLogData(xlog_data) => {
connection_status.latest_connection_update = now;
connection_status.commit_lsn = Some(Lsn::from(xlog_data.wal_end()));
connection_status.streaming_lsn = Some(Lsn::from(
xlog_data.wal_start() + xlog_data.data().len() as u64,
));
if !xlog_data.data().is_empty() {
connection_status.latest_wal_update = now;
}
}
ReplicationMessage::PrimaryKeepAlive(keepalive) => {
connection_status.latest_connection_update = now;
connection_status.commit_lsn = Some(Lsn::from(keepalive.wal_end()));
@@ -343,6 +352,7 @@ pub(super) async fn handle_walreceiver_connection(
// were interpreted.
let streaming_lsn = Lsn::from(raw.streaming_lsn());
let (format, compression) = interpreted_proto_config.unwrap();
let batch = InterpretedWalRecords::from_wire(raw.data(), format, compression)
.await
.with_context(|| {
@@ -498,6 +508,138 @@ pub(super) async fn handle_walreceiver_connection(
Some(streaming_lsn)
}
ReplicationMessage::XLogData(xlog_data) => {
async fn commit(
modification: &mut DatadirModification<'_>,
uncommitted: &mut u64,
filtered: &mut u64,
ctx: &RequestContext,
) -> anyhow::Result<()> {
let stats = modification.stats();
modification.commit(ctx).await?;
WAL_INGEST
.records_committed
.inc_by(*uncommitted - *filtered);
WAL_INGEST.inc_values_committed(&stats);
*uncommitted = 0;
*filtered = 0;
Ok(())
}
// Pass the WAL data to the decoder, and see if we can decode
// more records as a result.
let data = xlog_data.data();
let startlsn = Lsn::from(xlog_data.wal_start());
let endlsn = startlsn + data.len() as u64;
trace!("received XLogData between {startlsn} and {endlsn}");
WAL_INGEST.bytes_received.inc_by(data.len() as u64);
waldecoder.feed_bytes(data);
{
let mut modification = timeline.begin_modification(startlsn);
let mut uncommitted_records = 0;
let mut filtered_records = 0;
while let Some((next_record_lsn, recdata)) = waldecoder.poll_decode()? {
// It is important to deal with aligned records: the LSN in a getPage@LSN request is
// aligned and can be several bytes larger. Without this alignment we are
// at risk of hitting a deadlock.
if !next_record_lsn.is_aligned() {
return Err(WalReceiverError::Other(anyhow!("LSN not aligned")));
}
// Deserialize and interpret WAL record
let interpreted = InterpretedWalRecord::from_bytes_filtered(
recdata,
&shard,
next_record_lsn,
modification.tline.pg_version,
)?
.remove(timeline.get_shard_identity())
.unwrap();
if matches!(interpreted.flush_uncommitted, FlushUncommittedRecords::Yes)
&& uncommitted_records > 0
{
// Special case: legacy PG database creations operate by reading pages from a 'template' database:
// these are the only kinds of WAL record that require reading data blocks while ingesting. Ensure
// all earlier writes of data blocks are visible by committing any modification in flight.
commit(
&mut modification,
&mut uncommitted_records,
&mut filtered_records,
&ctx,
)
.await?;
}
// Ingest the records without immediately committing them.
timeline.metrics.wal_records_received.inc();
let ingested = walingest
.ingest_record(interpreted, &mut modification, &ctx)
.await
.with_context(|| {
format!("could not ingest record at {next_record_lsn}")
})
.inspect_err(|err| {
// TODO: we can't differentiate cancellation errors with
// anyhow::Error, so just ignore it if we're cancelled.
if !cancellation.is_cancelled() && !timeline.is_stopping() {
critical!("{err:?}")
}
})?;
if !ingested {
tracing::debug!("ingest: filtered out record @ LSN {next_record_lsn}");
WAL_INGEST.records_filtered.inc();
filtered_records += 1;
}
// FIXME: this cannot be made pausable_failpoint without fixing the
// failpoint library; in tests, the added amount of debugging will cause us
// to timeout the tests.
fail_point!("walreceiver-after-ingest");
last_rec_lsn = next_record_lsn;
// Commit every ingest_batch_size records. Even if we filtered out
// all records, we still need to call commit to advance the LSN.
uncommitted_records += 1;
if uncommitted_records >= ingest_batch_size
|| modification.approx_pending_bytes()
> DatadirModification::MAX_PENDING_BYTES
{
commit(
&mut modification,
&mut uncommitted_records,
&mut filtered_records,
&ctx,
)
.await?;
}
}
// Commit the remaining records.
if uncommitted_records > 0 {
commit(
&mut modification,
&mut uncommitted_records,
&mut filtered_records,
&ctx,
)
.await?;
}
}
if !caught_up && endlsn >= end_of_wal {
info!("caught up at LSN {endlsn}");
caught_up = true;
}
Some(endlsn)
}
ReplicationMessage::PrimaryKeepAlive(keepalive) => {
let wal_end = keepalive.wal_end();
let timestamp = keepalive.timestamp();
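The XLogData arm added above follows an "accumulate, then commit on a threshold" pattern: records are ingested into an open modification and committed once a record-count or pending-bytes threshold is reached, with a forced commit before any record flagged with `FlushUncommittedRecords::Yes`. A minimal, self-contained sketch of just that policy; the types and thresholds are placeholders, not the real `DatadirModification`, `ingest_batch_size`, or `MAX_PENDING_BYTES`:

```rust
/// Stand-in for an ingested-but-uncommitted batch of WAL records.
struct Batch {
    uncommitted: u64,
    pending_bytes: usize,
    committed_total: u64,
}

impl Batch {
    fn new() -> Self {
        Batch { uncommitted: 0, pending_bytes: 0, committed_total: 0 }
    }

    /// Mirrors the `commit` helper above: flush everything accumulated so far
    /// and reset the per-batch counters.
    fn commit(&mut self) {
        self.committed_total += self.uncommitted;
        self.uncommitted = 0;
        self.pending_bytes = 0;
    }

    /// Ingest one record of `size` bytes; `flush_first` models the
    /// FlushUncommittedRecords::Yes special case that forces a commit
    /// before the record itself is applied.
    fn ingest(&mut self, size: usize, flush_first: bool, batch_size: u64, max_pending: usize) {
        if flush_first && self.uncommitted > 0 {
            self.commit();
        }
        self.uncommitted += 1;
        self.pending_bytes += size;
        if self.uncommitted >= batch_size || self.pending_bytes > max_pending {
            self.commit();
        }
    }
}

fn main() {
    // Placeholder thresholds: commit every 3 records or once 8 KiB is pending.
    let (batch_size, max_pending) = (3, 8192);
    let mut batch = Batch::new();
    for i in 0..10 {
        batch.ingest(1024, i == 5, batch_size, max_pending);
    }
    if batch.uncommitted > 0 {
        batch.commit(); // commit the remainder, as the loop above does
    }
    assert_eq!(batch.committed_total, 10);
    println!("ok");
}
```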

View File

@@ -1,5 +1,6 @@
# pgxs/neon/Makefile
MODULE_big = neon
OBJS = \
$(WIN32RES) \
@@ -21,8 +22,7 @@ OBJS = \
walproposer.o \
walproposer_pg.o \
control_plane_connector.o \
walsender_hooks.o \
$(LIBCOMMUNICATOR_PATH)/libcommunicator.a
walsender_hooks.o
PG_CPPFLAGS = -I$(libpq_srcdir)
SHLIB_LINK_INTERNAL = $(libpq)

View File

@@ -1,13 +0,0 @@
[package]
name = "communicator"
version = "0.1.0"
edition = "2024"
[lib]
crate-type = ["staticlib"]
[dependencies]
neon-shmem.workspace = true
[build-dependencies]
cbindgen.workspace = true

View File

@@ -1,8 +0,0 @@
This package will evolve into a "compute-pageserver communicator"
process and machinery. For now, it just provides wrappers around the
neon-shmem Rust crate, to allow using it in the C implementation of
the LFC.
At compilation time, pgxn/neon/communicator/ produces a static
library, libcommunicator.a. It is linked to the neon.so extension
library.

View File

@@ -1,22 +0,0 @@
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
cbindgen::generate(crate_dir).map_or_else(
|error| match error {
cbindgen::Error::ParseSyntaxError { .. } => {
// This means there was a syntax error in the Rust sources. Don't panic, because
// we want the build to continue and the Rust compiler to hit the error. The
// Rust compiler produces a better error message than cbindgen.
eprintln!("Generating C bindings failed because of a Rust syntax error");
}
e => panic!("Unable to generate C bindings: {:?}", e),
},
|bindings| {
bindings.write_to_file("communicator_bindings.h");
},
);
Ok(())
}

View File

@@ -1,4 +0,0 @@
language = "C"
[enum]
prefix_with_name = true

View File

@@ -1,240 +0,0 @@
//! Glue code to allow using the Rust shmem hash map implementation from C code
//!
//! For convenience of adapting existing code, the interface provided somewhat resembles the dynahash
//! interface.
//!
//! NOTE: The caller is responsible for locking! The caller is expected to hold the PostgreSQL
//! LWLock, 'lfc_lock', while accessing the hash table, in shared or exclusive mode as appropriate.
use std::ffi::c_void;
use std::marker::PhantomData;
use neon_shmem::hash::entry::Entry;
use neon_shmem::hash::{HashMapAccess, HashMapInit};
use neon_shmem::shmem::ShmemHandle;
/// NB: This must match the definition of BufferTag in Postgres C headers. We could use bindgen to
/// generate this from the C headers, but prefer to not introduce dependency on bindgen for now.
///
/// Note that there are no padding bytes. If the corresponding C struct has padding bytes, the C
/// code must clear them.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
#[repr(C)]
pub struct FileCacheKey {
pub _spc_id: u32,
pub _db_id: u32,
pub _rel_number: u32,
pub _fork_num: u32,
pub _block_num: u32,
}
/// Like with FileCacheKey, this must match the definition of FileCacheEntry in file_cache.c. We
/// don't look at the contents here though, it's sufficient that the size and alignment match.
#[derive(Clone, Debug, Default)]
#[repr(C)]
pub struct FileCacheEntry {
pub _offset: u32,
pub _access_count: u32,
pub _prev: *mut FileCacheEntry,
pub _next: *mut FileCacheEntry,
pub _state: [u32; 8],
}
/// XXX: This could be just:
///
/// ```ignore
/// type FileCacheHashMapHandle = HashMapInit<'a, FileCacheKey, FileCacheEntry>
/// ```
///
/// but with that, cbindgen generates a broken typedef in the C header file which doesn't
/// compile. It apparently gets confused by the generics.
#[repr(transparent)]
pub struct FileCacheHashMapHandle<'a>(
pub *mut c_void,
PhantomData<HashMapInit<'a, FileCacheKey, FileCacheEntry>>,
);
impl<'a> From<Box<HashMapInit<'a, FileCacheKey, FileCacheEntry>>> for FileCacheHashMapHandle<'a> {
fn from(x: Box<HashMapInit<'a, FileCacheKey, FileCacheEntry>>) -> Self {
FileCacheHashMapHandle(Box::into_raw(x) as *mut c_void, PhantomData::default())
}
}
impl<'a> From<FileCacheHashMapHandle<'a>> for Box<HashMapInit<'a, FileCacheKey, FileCacheEntry>> {
fn from(x: FileCacheHashMapHandle) -> Self {
unsafe { Box::from_raw(x.0.cast()) }
}
}
/// XXX: same for this
#[repr(transparent)]
pub struct FileCacheHashMapAccess<'a>(
pub *mut c_void,
PhantomData<HashMapAccess<'a, FileCacheKey, FileCacheEntry>>,
);
impl<'a> From<Box<HashMapAccess<'a, FileCacheKey, FileCacheEntry>>> for FileCacheHashMapAccess<'a> {
fn from(x: Box<HashMapAccess<'a, FileCacheKey, FileCacheEntry>>) -> Self {
// Convert the Box into a raw mutable pointer to the HashMapAccess itself.
// This transfers ownership of the HashMapAccess (and its contained ShmemHandle)
// to the raw pointer. The C caller is now responsible for managing this memory.
FileCacheHashMapAccess(Box::into_raw(x) as *mut c_void, PhantomData::default())
}
}
impl<'a> FileCacheHashMapAccess<'a> {
fn as_ref(self) -> &'a HashMapAccess<'a, FileCacheKey, FileCacheEntry> {
let ptr: *mut HashMapAccess<'_, FileCacheKey, FileCacheEntry> = self.0.cast();
unsafe { ptr.as_ref().unwrap() }
}
fn as_mut(self) -> &'a mut HashMapAccess<'a, FileCacheKey, FileCacheEntry> {
let ptr: *mut HashMapAccess<'_, FileCacheKey, FileCacheEntry> = self.0.cast();
unsafe { ptr.as_mut().unwrap() }
}
}
/// Initialize the shared memory area at postmaster startup. The returned handle is inherited
/// by all the backend processes across fork()
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_shmem_init<'a>(
initial_num_buckets: u32,
max_num_buckets: u32,
) -> FileCacheHashMapHandle<'a> {
let max_bytes = HashMapInit::<FileCacheKey, FileCacheEntry>::estimate_size(max_num_buckets);
let shmem_handle =
ShmemHandle::new("lfc mapping", 0, max_bytes).expect("shmem initialization failed");
let handle = HashMapInit::<FileCacheKey, FileCacheEntry>::init_in_shmem(
initial_num_buckets,
shmem_handle,
);
Box::new(handle).into()
}
/// Initialize the access to the shared memory area in a backend process.
///
/// XXX: I'm not sure if this actually gets called in each process, or if the returned struct
/// is also inherited across fork(). It currently works either way but if this did more
/// initialization that needed to be done after fork(), then it would matter.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_shmem_access<'a>(
handle: FileCacheHashMapHandle<'a>,
) -> FileCacheHashMapAccess<'a> {
let handle: Box<HashMapInit<'_, FileCacheKey, FileCacheEntry>> = handle.into();
Box::new(handle.attach_writer()).into()
}
/// Return the current number of buckets in the hash table
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_get_num_buckets<'a>(
map: FileCacheHashMapAccess<'static>,
) -> u32 {
let map = map.as_ref();
map.get_num_buckets().try_into().unwrap()
}
/// Look up the entry with given key and hash.
///
/// This is similar to dynahash's hash_search(... , HASH_FIND)
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_find<'a>(
map: FileCacheHashMapAccess<'static>,
key: &FileCacheKey,
hash: u64,
) -> Option<&'static FileCacheEntry> {
let map = map.as_ref();
map.get_with_hash(key, hash)
}
/// Look up the entry at given bucket position
///
/// This has no direct equivalent in the dynahash interface, but can be used to
/// iterate through all entries in the hash table.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_get_at_pos<'a>(
map: FileCacheHashMapAccess<'static>,
pos: u32,
) -> Option<&'static FileCacheEntry> {
let map = map.as_ref();
map.get_at_bucket(pos as usize).map(|(_k, v)| v)
}
/// Remove entry, given a pointer to the value.
///
/// This is equivalent to dynahash hash_search(entry->key, HASH_REMOVE), where 'entry'
/// is an entry you have previously looked up.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_remove_entry<'a, 'b>(
map: FileCacheHashMapAccess,
entry: *mut FileCacheEntry,
) {
let map = map.as_mut();
let pos = map.get_bucket_for_value(entry);
match map.entry_at_bucket(pos) {
Some(e) => {
e.remove();
}
None => {
// todo: shouldn't happen, panic?
}
}
}
/// Compute the hash for given key
///
/// This is equivalent to dynahash get_hash_value() function. We use Rust's default hasher
/// for calculating the hash though.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_get_hash_value<'a, 'b>(
map: FileCacheHashMapAccess<'static>,
key: &FileCacheKey,
) -> u64 {
map.as_ref().get_hash_value(key)
}
/// Insert a new entry to the hash table
///
/// This is equivalent to dynahash hash_search(..., HASH_ENTER).
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_enter<'a, 'b>(
map: FileCacheHashMapAccess,
key: &FileCacheKey,
hash: u64,
found: &mut bool,
) -> *mut FileCacheEntry {
match map.as_mut().entry_with_hash(key.clone(), hash) {
Entry::Occupied(mut e) => {
*found = true;
e.get_mut()
}
Entry::Vacant(e) => {
*found = false;
let initial_value = FileCacheEntry::default();
e.insert(initial_value).expect("TODO: hash table full")
}
}
}
/// Get the key for a given entry, which must be present in the hash table.
///
/// Dynahash requires the key to be part of the "value" struct, so you can always
/// access the key with something like `entry->key`. The Rust implementation however
/// stores the key separately. This function extracts the separately stored key.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_get_key_for_entry<'a, 'b>(
map: FileCacheHashMapAccess,
entry: *const FileCacheEntry,
) -> Option<&FileCacheKey> {
let map = map.as_ref();
let pos = map.get_bucket_for_value(entry);
map.get_at_bucket(pos as usize).map(|(k, _v)| k)
}
/// Remove all entries from the hash table
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_file_cache_hash_reset<'a, 'b>(map: FileCacheHashMapAccess) {
let map = map.as_mut();
let num_buckets = map.get_num_buckets();
for i in 0..num_buckets {
if let Some(e) = map.entry_at_bucket(i) {
e.remove();
}
}
}
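The HASH_ENTER emulation in `bcomm_file_cache_hash_enter` maps naturally onto Rust's entry API; here is a self-contained illustration of the same found/not-found contract, with `std::collections::HashMap` standing in for the shared-memory map and heavily simplified key/entry types:

```rust
use std::collections::HashMap;
use std::collections::hash_map::Entry;

#[derive(Clone, Hash, Eq, PartialEq)]
struct Key {
    rel_number: u32,
    block_num: u32,
}

#[derive(Default)]
struct CacheEntry {
    offset: u32,
}

/// Mirrors bcomm_file_cache_hash_enter: return the existing entry and set
/// `found`, or insert a default-initialized one.
fn hash_enter<'a>(
    map: &'a mut HashMap<Key, CacheEntry>,
    key: &Key,
    found: &mut bool,
) -> &'a mut CacheEntry {
    match map.entry(key.clone()) {
        Entry::Occupied(e) => {
            *found = true;
            e.into_mut()
        }
        Entry::Vacant(e) => {
            *found = false;
            e.insert(CacheEntry::default())
        }
    }
}

fn main() {
    let mut map = HashMap::new();
    let key = Key { rel_number: 16384, block_num: 0 };
    let mut found = false;
    hash_enter(&mut map, &key, &mut found).offset = 7;
    assert!(!found); // first call inserts
    let entry = hash_enter(&mut map, &key, &mut found);
    assert!(found); // second call finds the existing entry
    assert_eq!(entry.offset, 7);
    println!("ok");
}
```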

View File

@@ -1 +0,0 @@
pub mod file_cache_hashmap;

View File

@@ -21,7 +21,7 @@
#include "access/xlog.h"
#include "funcapi.h"
#include "miscadmin.h"
#include "common/file_utils.h"
#include "common/hashfn.h"
#include "pgstat.h"
#include "port/pg_iovec.h"
#include "postmaster/bgworker.h"
@@ -36,6 +36,7 @@
#include "storage/procsignal.h"
#include "tcop/tcopprot.h"
#include "utils/builtins.h"
#include "utils/dynahash.h"
#include "utils/guc.h"
#if PG_VERSION_NUM >= 150000
@@ -45,7 +46,6 @@
#include "hll.h"
#include "bitmap.h"
#include "file_cache.h"
#include "file_cache_rust_hash.h"
#include "neon.h"
#include "neon_lwlsncache.h"
#include "neon_perf_counters.h"
@@ -64,7 +64,7 @@
*
* Cache is always reconstructed at node startup, so we do not need to save mapping somewhere and worry about
* its consistency.
*
*
* ## Holes
*
@@ -76,15 +76,13 @@
* fallocate(FALLOC_FL_PUNCH_HOLE) call. The nominal size of the file doesn't
* shrink, but the disk space it uses does.
*
* Each hole is tracked in a freelist. The freelist consists of two parts: a
* fixed-size array in shared memory, and a linked chain of on-disk
* blocks. When the in-memory array fills up, it's flushed to a new on-disk
* chunk. If the soft limit is raised again, we reuse the holes before
* extending the nominal size of the file.
*
* The in-memory freelist array is protected by 'lfc_lock', while the on-disk
* chain is protected by a separate 'lfc_freelist_lock'. Locking rule to
* avoid deadlocks: always acquire lfc_freelist_lock first, then lfc_lock.
* Each hole is tracked by a dummy FileCacheEntry, which is kept in the
* 'holes' linked list. They are entered into the chunk hash table, with a
* special key where the blockNumber is used to store the 'offset' of the
* hole, and all other fields are zero. Holes are never looked up in the hash
* table, we only enter them there to have a FileCacheEntry that we can keep
* in the linked list. If the soft limit is raised again, we reuse the holes
* before extending the nominal size of the file.
*/
/* Local file storage allocation chunk.
@@ -94,15 +92,13 @@
* 1Mb chunks can reduce hash map size to 320Mb.
* 2. Improve access locality: subsequent pages will be allocated together, improving seqscan speed
*/
#define BLOCKS_PER_CHUNK_LOG 7 /* 1Mb chunk */
#define BLOCKS_PER_CHUNK (1 << BLOCKS_PER_CHUNK_LOG)
#define MAX_BLOCKS_PER_CHUNK_LOG 7 /* 1Mb chunk */
#define MAX_BLOCKS_PER_CHUNK (1 << MAX_BLOCKS_PER_CHUNK_LOG)
#define MB ((uint64)1024*1024)
#define SIZE_MB_TO_CHUNKS(size) ((uint32)((size) * MB / BLCKSZ >> BLOCKS_PER_CHUNK_LOG))
#define BLOCK_TO_CHUNK_OFF(blkno) ((blkno) & (BLOCKS_PER_CHUNK-1))
#define INVALID_OFFSET (0xffffffff)
#define SIZE_MB_TO_CHUNKS(size) ((uint32)((size) * MB / BLCKSZ >> lfc_chunk_size_log))
#define BLOCK_TO_CHUNK_OFF(blkno) ((blkno) & (lfc_blocks_per_chunk-1))
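The macros above are plain shift-and-mask arithmetic; the same math written out in Rust as a stand-alone check, using the 1 MiB chunk / 8 KiB block values that appear in this hunk:

```rust
const BLCKSZ: u64 = 8192; // Postgres block size
const MB: u64 = 1024 * 1024;
const BLOCKS_PER_CHUNK_LOG: u32 = 7; // 128 blocks = 1 MiB per chunk
const BLOCKS_PER_CHUNK: u32 = 1 << BLOCKS_PER_CHUNK_LOG;

/// SIZE_MB_TO_CHUNKS: how many chunks fit in a cache of `size_mb` megabytes.
fn size_mb_to_chunks(size_mb: u64) -> u32 {
    ((size_mb * MB / BLCKSZ) >> BLOCKS_PER_CHUNK_LOG) as u32
}

/// BLOCK_TO_CHUNK_OFF: position of a block inside its chunk.
fn block_to_chunk_off(blkno: u32) -> u32 {
    blkno & (BLOCKS_PER_CHUNK - 1)
}

fn main() {
    assert_eq!(size_mb_to_chunks(1024), 1024); // a 1 GiB cache holds 1024 one-MiB chunks
    assert_eq!(block_to_chunk_off(130), 2);    // block 130 is the third block of chunk 1
    println!("ok");
}
```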
/*
* Blocks are read or written to LFC file outside LFC critical section.
@@ -123,18 +119,15 @@ typedef enum FileCacheBlockState
typedef struct FileCacheEntry
{
BufferTag key;
uint32 hash;
uint32 offset;
uint32 access_count;
dlist_node list_node; /* LRU list node */
uint32 state[(BLOCKS_PER_CHUNK * 2 + 31) / 32]; /* two bits per block */
dlist_node list_node; /* LRU/holes list node */
uint32 state[FLEXIBLE_ARRAY_MEMBER]; /* two bits per block */
} FileCacheEntry;
/* Todo: alignment must be the same too */
StaticAssertDecl(sizeof(FileCacheEntry) == sizeof(RustFileCacheEntry),
"Rust and C declarations of FileCacheEntry are incompatible");
StaticAssertDecl(sizeof(BufferTag) == sizeof(RustFileCacheKey),
"Rust and C declarations of FileCacheKey are incompatible");
#define FILE_CACHE_ENRTY_SIZE MAXALIGN(offsetof(FileCacheEntry, state) + (lfc_blocks_per_chunk*2+31)/32*4)
#define GET_STATE(entry, i) (((entry)->state[(i) / 16] >> ((i) % 16 * 2)) & 3)
#define SET_STATE(entry, i, new_state) (entry)->state[(i) / 16] = ((entry)->state[(i) / 16] & ~(3 << ((i) % 16 * 2))) | ((new_state) << ((i) % 16 * 2))
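GET_STATE/SET_STATE pack a two-bit state per block into an array of 32-bit words (16 states per word). The same packing in safe Rust, as a stand-alone illustration of the bit layout:

```rust
/// Two-bit states packed 16 per u32, matching the GET_STATE/SET_STATE macros above.
fn get_state(words: &[u32], i: usize) -> u32 {
    (words[i / 16] >> (i % 16 * 2)) & 3
}

fn set_state(words: &mut [u32], i: usize, new_state: u32) {
    debug_assert!(new_state < 4);
    words[i / 16] = (words[i / 16] & !(3 << (i % 16 * 2))) | (new_state << (i % 16 * 2));
}

fn main() {
    // Enough words for 128 blocks at two bits each.
    let mut state = [0u32; (128 * 2 + 31) / 32];
    set_state(&mut state, 37, 2);
    assert_eq!(get_state(&state, 37), 2);
    assert_eq!(get_state(&state, 38), 0); // neighbouring block untouched
    println!("ok");
}
```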
@@ -143,9 +136,6 @@ StaticAssertDecl(sizeof(BufferTag) == sizeof(RustFileCacheKey),
#define MAX_PREWARM_WORKERS 8
#define FREELIST_ENTRIES_PER_CHUNK (BLOCKS_PER_CHUNK * BLCKSZ / sizeof(uint32) - 2)
typedef struct PrewarmWorkerState
{
uint32 prewarmed_pages;
@@ -171,6 +161,7 @@ typedef struct FileCacheControl
uint64 evicted_pages; /* number of evicted pages */
dlist_head lru; /* double linked list for LRU replacement
* algorithm */
dlist_head holes; /* double linked list of punched holes */
HyperLogLogState wss_estimation; /* estimation of working set size */
ConditionVariable cv[N_COND_VARS]; /* turnstile of condition variables */
PrewarmWorkerState prewarm_workers[MAX_PREWARM_WORKERS];
@@ -181,39 +172,23 @@ typedef struct FileCacheControl
bool prewarm_active;
bool prewarm_canceled;
dsm_handle prewarm_lfc_state_handle;
/*
* Free list. This is large enough to hold one chunk's worth of entries.
*/
uint32 freelist_size;
uint32 freelist_head;
uint32 num_free_pages;
uint32 free_pages[FREELIST_ENTRIES_PER_CHUNK];
} FileCacheControl;
typedef struct FreeListChunk
{
uint32 next;
uint32 num_free_pages;
uint32 free_pages[FREELIST_ENTRIES_PER_CHUNK];
} FreeListChunk;
#define FILE_CACHE_STATE_MAGIC 0xfcfcfcfc
#define FILE_CACHE_STATE_BITMAP(fcs) ((uint8*)&(fcs)->chunks[(fcs)->n_chunks])
#define FILE_CACHE_STATE_SIZE_FOR_CHUNKS(n_chunks) (sizeof(FileCacheState) + (n_chunks)*sizeof(BufferTag) + (((n_chunks) * BLOCKS_PER_CHUNK)+7)/8)
#define FILE_CACHE_STATE_SIZE_FOR_CHUNKS(n_chunks) (sizeof(FileCacheState) + (n_chunks)*sizeof(BufferTag) + (((n_chunks) * lfc_blocks_per_chunk)+7)/8)
#define FILE_CACHE_STATE_SIZE(fcs) (sizeof(FileCacheState) + (fcs->n_chunks)*sizeof(BufferTag) + (((fcs->n_chunks) << fcs->chunk_size_log)+7)/8)
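As a worked example of the sizing macros: with 128-block chunks, a FileCacheState covering 1000 chunks takes the fixed header plus 1000 BufferTag keys plus a ((1000 * 128) + 7) / 8 = 16000-byte presence bitmap.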
static FileCacheHashMapHandle lfc_hash_handle;
static FileCacheHashMapAccess lfc_hash;
static HTAB *lfc_hash;
static int lfc_desc = -1;
static LWLockId lfc_lock;
static LWLockId lfc_freelist_lock;
static int lfc_max_size;
static int lfc_size_limit;
static int lfc_prewarm_limit;
static int lfc_prewarm_batch;
static int lfc_blocks_per_chunk_ro = BLOCKS_PER_CHUNK;
static int lfc_chunk_size_log = MAX_BLOCKS_PER_CHUNK_LOG;
static int lfc_blocks_per_chunk = MAX_BLOCKS_PER_CHUNK;
static char *lfc_path;
static uint64 lfc_generation;
static FileCacheControl *lfc_ctl;
@@ -230,11 +205,6 @@ bool AmPrewarmWorker;
#define LFC_ENABLED() (lfc_ctl->limit != 0)
static bool freelist_push(uint32 offset);
static bool freelist_prepare_pop(void);
static uint32 freelist_pop(void);
static bool freelist_is_empty(void);
/*
* Close LFC file if opened.
* All backends should close their LFC files once LFC is disabled.
@@ -262,9 +232,15 @@ lfc_switch_off(void)
if (LFC_ENABLED())
{
/* Invalidate hash */
file_cache_hash_reset(lfc_hash);
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
/* Invalidate hash */
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
hash_search_with_hash_value(lfc_hash, &entry->key, entry->hash, HASH_REMOVE, NULL);
}
lfc_ctl->generation += 1;
lfc_ctl->size = 0;
lfc_ctl->pinned = 0;
@@ -272,9 +248,7 @@ lfc_switch_off(void)
lfc_ctl->used_pages = 0;
lfc_ctl->limit = 0;
dlist_init(&lfc_ctl->lru);
lfc_ctl->freelist_head = INVALID_OFFSET;
lfc_ctl->num_free_pages = 0;
dlist_init(&lfc_ctl->holes);
/*
	 * We need to use unlink to avoid races in LFC write, because it is not
@@ -343,8 +317,8 @@ lfc_ensure_opened(void)
static void
lfc_shmem_startup(void)
{
size_t size;
bool found;
static HASHCTL info;
if (prev_shmem_startup_hook)
{
@@ -353,29 +327,27 @@ lfc_shmem_startup(void)
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
size = sizeof(FileCacheControl);
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", size, &found);
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", sizeof(FileCacheControl), &found);
if (!found)
{
int fd;
uint32 n_chunks = SIZE_MB_TO_CHUNKS(lfc_max_size);
lfc_lock = (LWLockId) GetNamedLWLockTranche("lfc_lock");
lfc_freelist_lock = (LWLockId) GetNamedLWLockTranche("lfc_freelist_lock");
info.keysize = sizeof(BufferTag);
info.entrysize = FILE_CACHE_ENRTY_SIZE;
/*
* n_chunks+1 because we add new element to hash table before eviction
* of victim
*/
lfc_hash_handle = file_cache_hash_shmem_init(n_chunks + 1, n_chunks + 1);
memset(lfc_ctl, 0, offsetof(FileCacheControl, free_pages));
lfc_hash = ShmemInitHash("lfc_hash",
n_chunks + 1, n_chunks + 1,
&info,
HASH_ELEM | HASH_BLOBS);
memset(lfc_ctl, 0, sizeof(FileCacheControl));
dlist_init(&lfc_ctl->lru);
lfc_ctl->freelist_size = FREELIST_ENTRIES_PER_CHUNK;
lfc_ctl->freelist_head = INVALID_OFFSET;
lfc_ctl->num_free_pages = 0;
dlist_init(&lfc_ctl->holes);
/* Initialize hyper-log-log structure for estimating working set size */
initSHLL(&lfc_ctl->wss_estimation);
@@ -399,25 +371,18 @@ lfc_shmem_startup(void)
}
LWLockRelease(AddinShmemInitLock);
lfc_hash = file_cache_hash_shmem_access(lfc_hash_handle);
}
static void
lfc_shmem_request(void)
{
size_t size;
#if PG_VERSION_NUM>=150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
size = sizeof(FileCacheControl);
RequestAddinShmemSpace(size);
RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
RequestNamedLWLockTranche("lfc_lock", 1);
RequestNamedLWLockTranche("lfc_freelist_lock", 2);
}
static bool
@@ -433,6 +398,24 @@ is_normal_backend(void)
return lfc_ctl && MyProc && UsedShmemSegAddr && !IsParallelWorker();
}
static bool
lfc_check_chunk_size(int *newval, void **extra, GucSource source)
{
if (*newval & (*newval - 1))
{
elog(ERROR, "LFC chunk size should be power of two");
return false;
}
return true;
}
static void
lfc_change_chunk_size(int newval, void* extra)
{
lfc_chunk_size_log = pg_ceil_log2_32(newval);
}
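A worked example of these two hooks: a requested chunk size of 32 blocks passes the check because 32 & 31 == 0, and lfc_chunk_size_log becomes pg_ceil_log2_32(32) = 5; a value such as 24 is rejected because 24 & 23 == 16, which is non-zero.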
static bool
lfc_check_limit_hook(int *newval, void **extra, GucSource source)
{
@@ -452,14 +435,12 @@ lfc_change_limit_hook(int newval, void *extra)
if (!lfc_ctl || !is_normal_backend())
return;
LWLockAcquire(lfc_freelist_lock, LW_EXCLUSIVE);
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
/* Open LFC file only if LFC was enabled or we are going to reenable it */
if (newval == 0 && !LFC_ENABLED())
{
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
/* File should be reopened if LFC is reenabled */
lfc_close_file();
return;
@@ -468,7 +449,6 @@ lfc_change_limit_hook(int newval, void *extra)
if (!lfc_ensure_opened())
{
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return;
}
@@ -484,30 +464,35 @@ lfc_change_limit_hook(int newval, void *extra)
* returning their space to file system
*/
FileCacheEntry *victim = dlist_container(FileCacheEntry, list_node, dlist_pop_head_node(&lfc_ctl->lru));
FileCacheEntry *hole;
uint32 offset = victim->offset;
uint32 hash;
bool found;
BufferTag holetag;
CriticalAssert(victim->access_count == 0);
#ifdef FALLOC_FL_PUNCH_HOLE
if (fallocate(lfc_desc, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, (off_t) victim->offset * BLOCKS_PER_CHUNK * BLCKSZ, BLOCKS_PER_CHUNK * BLCKSZ) < 0)
if (fallocate(lfc_desc, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, (off_t) victim->offset * lfc_blocks_per_chunk * BLCKSZ, lfc_blocks_per_chunk * BLCKSZ) < 0)
neon_log(LOG, "Failed to punch hole in file: %m");
#endif
/* We remove the entry, and enter a hole to the freelist */
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
/* We remove the old entry, and re-enter a hole to the hash table */
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
bool is_page_cached = GET_STATE(victim, i) == AVAILABLE;
lfc_ctl->used_pages -= is_page_cached;
lfc_ctl->evicted_pages += is_page_cached;
}
file_cache_hash_remove_entry(lfc_hash, victim);
hash_search_with_hash_value(lfc_hash, &victim->key, victim->hash, HASH_REMOVE, NULL);
if (!freelist_push(offset))
{
/* freelist_push already logged the error */
lfc_switch_off();
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return;
}
memset(&holetag, 0, sizeof(holetag));
holetag.blockNum = offset;
hash = get_hash_value(lfc_hash, &holetag);
hole = hash_search_with_hash_value(lfc_hash, &holetag, hash, HASH_ENTER, &found);
hole->hash = hash;
hole->offset = offset;
hole->access_count = 0;
CriticalAssert(!found);
dlist_push_tail(&lfc_ctl->holes, &hole->list_node);
lfc_ctl->used -= 1;
}
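Concretely, with the default 8 KB pages and 128-block chunks, shrinking the cache evicts the LRU chunk at, say, slot 5, punches a 1 MB hole starting at byte 5 * 1048576 of the cache file, and re-enters that slot in the hash table under a zeroed tag with blockNum = 5, linked into the holes list for later reuse.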
@@ -519,7 +504,6 @@ lfc_change_limit_hook(int newval, void *extra)
neon_log(DEBUG1, "set local file cache limit to %d", new_size);
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
}
void
@@ -595,14 +579,14 @@ lfc_init(void)
DefineCustomIntVariable("neon.file_cache_chunk_size",
"LFC chunk size in blocks (should be power of two)",
NULL,
&lfc_blocks_per_chunk_ro,
BLOCKS_PER_CHUNK,
BLOCKS_PER_CHUNK,
BLOCKS_PER_CHUNK,
PGC_INTERNAL,
&lfc_blocks_per_chunk,
MAX_BLOCKS_PER_CHUNK,
1,
MAX_BLOCKS_PER_CHUNK,
PGC_POSTMASTER,
GUC_UNIT_BLOCKS,
NULL,
NULL,
lfc_check_chunk_size,
lfc_change_chunk_size,
NULL);
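For example, starting the server with neon.file_cache_chunk_size = 64 (a PGC_POSTMASTER setting, so it cannot be changed without a restart) gives 512 KB chunks with the default 8 KB pages, and lfc_change_chunk_size records lfc_chunk_size_log = 6.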
DefineCustomIntVariable("neon.file_cache_prewarm_limit",
@@ -665,19 +649,19 @@ lfc_get_state(size_t max_entries)
fcs = (FileCacheState*)palloc0(state_size);
SET_VARSIZE(fcs, state_size);
fcs->magic = FILE_CACHE_STATE_MAGIC;
fcs->chunk_size_log = BLOCKS_PER_CHUNK_LOG;
fcs->chunk_size_log = lfc_chunk_size_log;
fcs->n_chunks = n_entries;
bitmap = FILE_CACHE_STATE_BITMAP(fcs);
dlist_reverse_foreach(iter, &lfc_ctl->lru)
{
FileCacheEntry *entry = dlist_container(FileCacheEntry, list_node, iter.cur);
fcs->chunks[i] = *file_cache_hash_get_key_for_entry(lfc_hash, entry);
for (int j = 0; j < BLOCKS_PER_CHUNK; j++)
fcs->chunks[i] = entry->key;
for (int j = 0; j < lfc_blocks_per_chunk; j++)
{
if (GET_STATE(entry, j) != UNAVAILABLE)
{
BITMAP_SET(bitmap, i*BLOCKS_PER_CHUNK + j);
BITMAP_SET(bitmap, i*lfc_blocks_per_chunk + j);
n_pages += 1;
}
}
@@ -686,7 +670,7 @@ lfc_get_state(size_t max_entries)
}
Assert(i == n_entries);
fcs->n_pages = n_pages;
Assert(pg_popcount((char*)bitmap, ((n_entries << BLOCKS_PER_CHUNK_LOG) + 7)/8) == n_pages);
Assert(pg_popcount((char*)bitmap, ((n_entries << lfc_chunk_size_log) + 7)/8) == n_pages);
elog(LOG, "LFC: save state of %d chunks %d pages", (int)n_entries, (int)n_pages);
}
@@ -742,7 +726,7 @@ lfc_prewarm(FileCacheState* fcs, uint32 n_workers)
}
fcs_chunk_size_log = fcs->chunk_size_log;
if (fcs_chunk_size_log > BLOCKS_PER_CHUNK_LOG)
if (fcs_chunk_size_log > MAX_BLOCKS_PER_CHUNK_LOG)
{
elog(ERROR, "LFC: Invalid chunk size log: %u", fcs->chunk_size_log);
}
@@ -961,7 +945,7 @@ lfc_invalidate(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber nblocks)
{
BufferTag tag;
FileCacheEntry *entry;
uint64 hash;
uint32 hash;
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
return;
@@ -974,14 +958,14 @@ lfc_invalidate(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber nblocks)
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (LFC_ENABLED())
{
for (BlockNumber blkno = 0; blkno < nblocks; blkno += BLOCKS_PER_CHUNK)
for (BlockNumber blkno = 0; blkno < nblocks; blkno += lfc_blocks_per_chunk)
{
tag.blockNum = blkno;
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
entry = file_cache_hash_find(lfc_hash, &tag, hash);
hash = get_hash_value(lfc_hash, &tag);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
if (entry != NULL)
{
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
if (GET_STATE(entry, i) == AVAILABLE)
{
@@ -1006,7 +990,7 @@ lfc_cache_contains(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
FileCacheEntry *entry;
int chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
bool found = false;
uint64 hash;
uint32 hash;
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
return false;
@@ -1016,12 +1000,12 @@ lfc_cache_contains(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
tag.blockNum = blkno - chunk_offs;
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
hash = get_hash_value(lfc_hash, &tag);
LWLockAcquire(lfc_lock, LW_SHARED);
if (LFC_ENABLED())
{
entry = file_cache_hash_find(lfc_hash, &tag, hash);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
found = entry != NULL && GET_STATE(entry, chunk_offs) != UNAVAILABLE;
}
LWLockRelease(lfc_lock);
@@ -1040,7 +1024,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
FileCacheEntry *entry;
uint32 chunk_offs;
int found = 0;
uint64 hash;
uint32 hash;
int i = 0;
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
@@ -1053,7 +1037,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
tag.blockNum = blkno - chunk_offs;
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
hash = get_hash_value(lfc_hash, &tag);
LWLockAcquire(lfc_lock, LW_SHARED);
@@ -1064,12 +1048,12 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
while (true)
{
int this_chunk = Min(nblocks - i, BLOCKS_PER_CHUNK - chunk_offs);
entry = file_cache_hash_find(lfc_hash, &tag, hash);
int this_chunk = Min(nblocks - i, lfc_blocks_per_chunk - chunk_offs);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
if (entry != NULL)
{
for (; chunk_offs < BLOCKS_PER_CHUNK && i < nblocks; chunk_offs++, i++)
for (; chunk_offs < lfc_blocks_per_chunk && i < nblocks; chunk_offs++, i++)
{
if (GET_STATE(entry, chunk_offs) != UNAVAILABLE)
{
@@ -1095,7 +1079,7 @@ lfc_cache_containsv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
*/
chunk_offs = BLOCK_TO_CHUNK_OFF(blkno + i);
tag.blockNum = (blkno + i) - chunk_offs;
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
hash = get_hash_value(lfc_hash, &tag);
}
LWLockRelease(lfc_lock);
@@ -1144,7 +1128,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
BufferTag tag;
FileCacheEntry *entry;
ssize_t rc;
uint64 hash;
uint32 hash;
uint64 generation;
uint32 entry_offset;
int blocks_read = 0;
@@ -1170,9 +1154,9 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
while (nblocks > 0)
{
struct iovec iov[PG_IOV_MAX];
uint8 chunk_mask[BLOCKS_PER_CHUNK / 8] = {0};
uint8 chunk_mask[MAX_BLOCKS_PER_CHUNK / 8] = {0};
int chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
int blocks_in_chunk = Min(nblocks, BLOCKS_PER_CHUNK - chunk_offs);
int blocks_in_chunk = Min(nblocks, lfc_blocks_per_chunk - chunk_offs);
int iteration_hits = 0;
int iteration_misses = 0;
uint64 io_time_us = 0;
@@ -1222,7 +1206,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
Assert(iov_last_used - first_block_in_chunk_read >= n_blocks_to_read);
tag.blockNum = blkno - chunk_offs;
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
hash = get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
@@ -1235,13 +1219,13 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
return blocks_read;
}
entry = file_cache_hash_find(lfc_hash, &tag, hash);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
/* Approximate working set for the blocks assumed in this entry */
for (int i = 0; i < blocks_in_chunk; i++)
{
tag.blockNum = blkno + i;
addSHLL(&lfc_ctl->wss_estimation, file_cache_hash_get_hash_value(lfc_hash, &tag));
addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
}
if (entry == NULL)
@@ -1312,7 +1296,7 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
if (iteration_hits != 0)
{
/* chunk offset (# of pages) into the LFC file */
off_t first_read_offset = (off_t) entry_offset * BLOCKS_PER_CHUNK;
off_t first_read_offset = (off_t) entry_offset * lfc_blocks_per_chunk;
int nwrite = iov_last_used - first_block_in_chunk_read;
/* offset of first IOV */
first_read_offset += chunk_offs + first_block_in_chunk_read;
@@ -1389,14 +1373,14 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
* Returns false if there are no unpinned entries and chunk can not be added.
*/
static bool
lfc_init_new_entry(FileCacheEntry *entry)
lfc_init_new_entry(FileCacheEntry* entry, uint32 hash)
{
/*-----------
* If the chunk wasn't already in the LFC then we have these
* options, in order of preference:
*
	 * If there is space available, we can:
* 1. Use an entry from the freelist, and
* 1. Use an entry from the `holes` list, and
* 2. Create a new entry.
* We can always, regardless of space in the LFC:
* 3. evict an entry from LRU, and
@@ -1404,10 +1388,17 @@ lfc_init_new_entry(FileCacheEntry *entry)
*/
if (lfc_ctl->used < lfc_ctl->limit)
{
if (!freelist_is_empty())
if (!dlist_is_empty(&lfc_ctl->holes))
{
/* We can reuse a hole that was left behind when the LFC was shrunk previously */
uint32 offset = freelist_pop();
FileCacheEntry *hole = dlist_container(FileCacheEntry, list_node,
dlist_pop_head_node(&lfc_ctl->holes));
uint32 offset = hole->offset;
bool hole_found;
hash_search_with_hash_value(lfc_hash, &hole->key,
hole->hash, HASH_REMOVE, &hole_found);
CriticalAssert(hole_found);
lfc_ctl->used += 1;
entry->offset = offset; /* reuse the hole */
@@ -1436,7 +1427,7 @@ lfc_init_new_entry(FileCacheEntry *entry)
FileCacheEntry *victim = dlist_container(FileCacheEntry, list_node,
dlist_pop_head_node(&lfc_ctl->lru));
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
bool is_page_cached = GET_STATE(victim, i) == AVAILABLE;
lfc_ctl->used_pages -= is_page_cached;
@@ -1445,21 +1436,24 @@ lfc_init_new_entry(FileCacheEntry *entry)
CriticalAssert(victim->access_count == 0);
entry->offset = victim->offset; /* grab victim's chunk */
file_cache_hash_remove_entry(lfc_hash, victim);
hash_search_with_hash_value(lfc_hash, &victim->key,
victim->hash, HASH_REMOVE, NULL);
neon_log(DEBUG2, "Swap file cache page");
}
else
{
/* Can't add this chunk - we don't have the space for it */
file_cache_hash_remove_entry(lfc_hash, entry);
hash_search_with_hash_value(lfc_hash, &entry->key, hash,
HASH_REMOVE, NULL);
lfc_ctl->prewarm_canceled = true; /* cancel prewarm if LFC limit is reached */
return false;
}
entry->access_count = 1;
entry->hash = hash;
lfc_ctl->pinned += 1;
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
for (int i = 0; i < lfc_blocks_per_chunk; i++)
SET_STATE(entry, i, UNAVAILABLE);
return true;
@@ -1496,7 +1490,7 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
FileCacheEntry *entry;
ssize_t rc;
bool found;
uint64 hash;
uint32 hash;
uint64 generation;
uint32 entry_offset;
instr_time io_start, io_end;
@@ -1515,10 +1509,9 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
tag.blockNum = blkno - chunk_offs;
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
hash = get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
retry:
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED() || !lfc_ensure_opened())
@@ -1527,9 +1520,6 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
return false;
}
if (!freelist_prepare_pop())
goto retry;
lwlsn = neon_get_lwlsn(rinfo, forknum, blkno);
if (lwlsn > lsn)
@@ -1540,12 +1530,12 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
return false;
}
entry = file_cache_hash_enter(lfc_hash, &tag, hash, &found);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_ENTER, &found);
if (lfc_prewarm_update_ws_estimation)
{
tag.blockNum = blkno;
addSHLL(&lfc_ctl->wss_estimation, file_cache_hash_get_hash_value(lfc_hash, &tag));
addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
}
if (found)
{
@@ -1567,7 +1557,7 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
}
else
{
if (!lfc_init_new_entry(entry))
if (!lfc_init_new_entry(entry, hash))
{
/*
* We can't process this chunk due to lack of space in LFC,
@@ -1588,7 +1578,7 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
pgstat_report_wait_start(WAIT_EVENT_NEON_LFC_WRITE);
INSTR_TIME_SET_CURRENT(io_start);
rc = pwrite(lfc_desc, buffer, BLCKSZ,
((off_t) entry_offset * BLOCKS_PER_CHUNK + chunk_offs) * BLCKSZ);
((off_t) entry_offset * lfc_blocks_per_chunk + chunk_offs) * BLCKSZ);
INSTR_TIME_SET_CURRENT(io_end);
pgstat_report_wait_end();
@@ -1650,7 +1640,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
FileCacheEntry *entry;
ssize_t rc;
bool found;
uint64 hash;
uint32 hash;
uint64 generation;
uint32 entry_offset;
int buf_offset = 0;
@@ -1663,7 +1653,6 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
retry:
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED() || !lfc_ensure_opened())
@@ -1673,9 +1662,6 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
generation = lfc_ctl->generation;
if (!freelist_prepare_pop())
goto retry;
/*
* For every chunk that has blocks we're interested in, we
* 1. get the chunk header
@@ -1689,7 +1675,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
{
struct iovec iov[PG_IOV_MAX];
int chunk_offs = BLOCK_TO_CHUNK_OFF(blkno);
int blocks_in_chunk = Min(nblocks, BLOCKS_PER_CHUNK - chunk_offs);
int blocks_in_chunk = Min(nblocks, lfc_blocks_per_chunk - chunk_offs);
instr_time io_start, io_end;
ConditionVariable* cv;
@@ -1702,16 +1688,16 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
tag.blockNum = blkno - chunk_offs;
hash = file_cache_hash_get_hash_value(lfc_hash, &tag);
hash = get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
entry = file_cache_hash_enter(lfc_hash, &tag, hash, &found);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_ENTER, &found);
/* Approximate working set for the blocks assumed in this entry */
for (int i = 0; i < blocks_in_chunk; i++)
{
tag.blockNum = blkno + i;
addSHLL(&lfc_ctl->wss_estimation, file_cache_hash_get_hash_value(lfc_hash, &tag));
addSHLL(&lfc_ctl->wss_estimation, hash_bytes((uint8_t const*)&tag, sizeof(tag)));
}
if (found)
@@ -1728,7 +1714,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
else
{
if (!lfc_init_new_entry(entry))
if (!lfc_init_new_entry(entry, hash))
{
/*
* We can't process this chunk due to lack of space in LFC,
@@ -1777,7 +1763,7 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
pgstat_report_wait_start(WAIT_EVENT_NEON_LFC_WRITE);
INSTR_TIME_SET_CURRENT(io_start);
rc = pwritev(lfc_desc, iov, blocks_in_chunk,
((off_t) entry_offset * BLOCKS_PER_CHUNK + chunk_offs) * BLCKSZ);
((off_t) entry_offset * lfc_blocks_per_chunk + chunk_offs) * BLCKSZ);
INSTR_TIME_SET_CURRENT(io_end);
pgstat_report_wait_end();
@@ -1837,140 +1823,6 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
LWLockRelease(lfc_lock);
}
/**** freelist management ****/
/*
* Prerequisites:
* - The caller is holding 'lfc_lock'. XXX
*/
static bool
freelist_prepare_pop(void)
{
/*
* If the in-memory freelist is empty, but there are more blocks available, load them.
*
* TODO: if there
*/
if (lfc_ctl->num_free_pages == 0 && lfc_ctl->freelist_head != INVALID_OFFSET)
{
uint32 freelist_head;
FreeListChunk *freelist_chunk;
size_t bytes_read;
LWLockRelease(lfc_lock);
LWLockAcquire(lfc_freelist_lock, LW_EXCLUSIVE);
if (!(lfc_ctl->num_free_pages == 0 && lfc_ctl->freelist_head != INVALID_OFFSET))
{
/* someone else did the work for us while we were not holding the lock */
LWLockRelease(lfc_freelist_lock);
return false;
}
freelist_head = lfc_ctl->freelist_head;
freelist_chunk = palloc(BLOCKS_PER_CHUNK * BLCKSZ);
bytes_read = 0;
while (bytes_read < BLOCKS_PER_CHUNK * BLCKSZ)
{
ssize_t rc;
rc = pread(lfc_desc, freelist_chunk, BLOCKS_PER_CHUNK * BLCKSZ - bytes_read, (off_t) freelist_head * BLOCKS_PER_CHUNK * BLCKSZ + bytes_read);
if (rc < 0)
{
lfc_disable("read freelist page");
return false;
}
bytes_read += rc;
}
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (lfc_generation != lfc_ctl->generation)
{
LWLockRelease(lfc_lock);
return false;
}
Assert(lfc_ctl->freelist_head == freelist_head);
Assert(lfc_ctl->num_free_pages == 0);
lfc_ctl->freelist_head = freelist_chunk->next;
lfc_ctl->num_free_pages = freelist_chunk->num_free_pages;
memcpy(lfc_ctl->free_pages, freelist_chunk->free_pages, lfc_ctl->num_free_pages * sizeof(uint32));
pfree(freelist_chunk);
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return false;
}
return true;
}
/*
* Prerequisites:
* - The caller is holding 'lfc_lock' and 'lfc_freelist_lock'.
*
* Returns 'false' on error.
*/
static bool
freelist_push(uint32 offset)
{
Assert(lfc_ctl->freelist_size == FREELIST_ENTRIES_PER_CHUNK);
if (lfc_ctl->num_free_pages == lfc_ctl->freelist_size)
{
FreeListChunk *freelist_chunk;
struct iovec iov;
ssize_t rc;
freelist_chunk = palloc(BLOCKS_PER_CHUNK * BLCKSZ);
/* write the existing entries to the chunk on disk */
freelist_chunk->next = lfc_ctl->freelist_head;
freelist_chunk->num_free_pages = lfc_ctl->num_free_pages;
memcpy(freelist_chunk->free_pages, lfc_ctl->free_pages, lfc_ctl->num_free_pages * sizeof(uint32));
/* Use the passed-in offset to hold the freelist chunk itself */
iov.iov_base = freelist_chunk;
iov.iov_len = BLOCKS_PER_CHUNK * BLCKSZ;
rc = pg_pwritev_with_retry(lfc_desc, &iov, 1, (off_t) offset * BLOCKS_PER_CHUNK * BLCKSZ);
pfree(freelist_chunk);
if (rc < 0)
return false;
lfc_ctl->freelist_head = offset;
lfc_ctl->num_free_pages = 0;
}
else
{
lfc_ctl->free_pages[lfc_ctl->num_free_pages] = offset;
lfc_ctl->num_free_pages++;
}
return true;
}
static uint32
freelist_pop(void)
{
uint32 result;
/* The caller should've checked that the list is not empty */
Assert(lfc_ctl->num_free_pages > 0);
result = lfc_ctl->free_pages[lfc_ctl->num_free_pages - 1];
lfc_ctl->num_free_pages--;
return result;
}
static bool
freelist_is_empty(void)
{
return lfc_ctl->num_free_pages == 0;
}
typedef struct
{
TupleDesc tupdesc;
@@ -2067,7 +1919,7 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS)
break;
case 8:
key = "file_cache_chunk_size_pages";
value = BLOCKS_PER_CHUNK;
value = lfc_blocks_per_chunk;
break;
case 9:
key = "file_cache_chunks_pinned";
@@ -2138,6 +1990,7 @@ local_cache_pages(PG_FUNCTION_ARGS)
if (SRF_IS_FIRSTCALL())
{
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
uint32 n_pages = 0;
@@ -2193,16 +2046,15 @@ local_cache_pages(PG_FUNCTION_ARGS)
if (LFC_ENABLED())
{
uint32 num_buckets = file_cache_hash_get_num_buckets(lfc_hash);
for (uint32 pos = 0; pos < num_buckets; pos++)
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
entry = file_cache_hash_get_at_pos(lfc_hash, pos);
if (entry == NULL)
continue;
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
/* Skip hole tags */
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
}
}
}
@@ -2224,28 +2076,25 @@ local_cache_pages(PG_FUNCTION_ARGS)
* in the fctx->record structure.
*/
uint32 n = 0;
uint32 num_buckets = file_cache_hash_get_num_buckets(lfc_hash);
for (uint32 pos = 0; pos < num_buckets; pos++)
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
entry = file_cache_hash_get_at_pos(lfc_hash, pos);
if (entry == NULL)
continue;
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
const BufferTag *key = file_cache_hash_get_key_for_entry(lfc_hash, entry);
if (GET_STATE(entry, i) == AVAILABLE)
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
fctx->record[n].pageoffs = entry->offset * BLOCKS_PER_CHUNK + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(*key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(*key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(*key));
fctx->record[n].forknum = key->forkNum;
fctx->record[n].blocknum = key->blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
if (GET_STATE(entry, i) == AVAILABLE)
{
fctx->record[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].forknum = entry->key.forkNum;
fctx->record[n].blocknum = entry->key.blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
}
}
}
}

View File

@@ -16,7 +16,6 @@
#if PG_MAJORVERSION_NUM >= 15
#include "access/xlogrecovery.h"
#endif
#include "executor/instrument.h"
#include "replication/logical.h"
#include "replication/logicallauncher.h"
#include "replication/slot.h"
@@ -34,7 +33,6 @@
#include "file_cache.h"
#include "neon.h"
#include "neon_lwlsncache.h"
#include "neon_perf_counters.h"
#include "control_plane_connector.h"
#include "logical_replication_monitor.h"
#include "unstable_extensions.h"
@@ -48,13 +46,6 @@ void _PG_init(void);
static int running_xacts_overflow_policy;
static bool monitor_query_exec_time = false;
static ExecutorStart_hook_type prev_ExecutorStart = NULL;
static ExecutorEnd_hook_type prev_ExecutorEnd = NULL;
static void neon_ExecutorStart(QueryDesc *queryDesc, int eflags);
static void neon_ExecutorEnd(QueryDesc *queryDesc);
#if PG_MAJORVERSION_NUM >= 16
static shmem_startup_hook_type prev_shmem_startup_hook;
@@ -479,16 +470,6 @@ _PG_init(void)
0,
NULL, NULL, NULL);
DefineCustomBoolVariable(
"neon.monitor_query_exec_time",
"Collect infortmation about query execution time",
NULL,
&monitor_query_exec_time,
false,
PGC_USERSET,
0,
NULL, NULL, NULL);
DefineCustomBoolVariable(
"neon.allow_replica_misconfig",
"Allow replica startup when some critical GUCs have smaller value than on primary node",
@@ -527,11 +508,6 @@ _PG_init(void)
EmitWarningsOnPlaceholders("neon");
ReportSearchPath();
prev_ExecutorStart = ExecutorStart_hook;
ExecutorStart_hook = neon_ExecutorStart;
prev_ExecutorEnd = ExecutorEnd_hook;
ExecutorEnd_hook = neon_ExecutorEnd;
}
PG_FUNCTION_INFO_V1(pg_cluster_size);
@@ -605,55 +581,3 @@ neon_shmem_startup_hook(void)
#endif
}
#endif
/*
* ExecutorStart hook: start up tracking if needed
*/
static void
neon_ExecutorStart(QueryDesc *queryDesc, int eflags)
{
if (prev_ExecutorStart)
prev_ExecutorStart(queryDesc, eflags);
else
standard_ExecutorStart(queryDesc, eflags);
if (monitor_query_exec_time)
{
/*
* Set up to track total elapsed time in ExecutorRun. Make sure the
* space is allocated in the per-query context so it will go away at
* ExecutorEnd.
*/
if (queryDesc->totaltime == NULL)
{
MemoryContext oldcxt;
oldcxt = MemoryContextSwitchTo(queryDesc->estate->es_query_cxt);
queryDesc->totaltime = InstrAlloc(1, INSTRUMENT_TIMER, false);
MemoryContextSwitchTo(oldcxt);
}
}
}
/*
* ExecutorEnd hook: store results if needed
*/
static void
neon_ExecutorEnd(QueryDesc *queryDesc)
{
if (monitor_query_exec_time && queryDesc->totaltime)
{
/*
* Make sure stats accumulation is done. (Note: it's okay if several
* levels of hook all do this.)
*/
InstrEndLoop(queryDesc->totaltime);
inc_query_time(queryDesc->totaltime->total*1000000); /* convert to usec */
}
if (prev_ExecutorEnd)
prev_ExecutorEnd(queryDesc);
else
standard_ExecutorEnd(queryDesc);
}

View File

@@ -71,27 +71,6 @@ inc_iohist(IOHistogram hist, uint64 latency_us)
hist->wait_us_count++;
}
static inline void
inc_qthist(QTHistogram hist, uint64 elapsed_us)
{
int lo = 0;
int hi = NUM_QT_BUCKETS - 1;
/* Find the right bucket with binary search */
while (lo < hi)
{
int mid = (lo + hi) / 2;
if (elapsed_us < qt_bucket_thresholds[mid])
hi = mid;
else
lo = mid + 1;
}
hist->elapsed_us_bucket[lo]++;
hist->elapsed_us_sum += elapsed_us;
hist->elapsed_us_count++;
}
/*
* Count a GetPage wait operation.
*/
@@ -119,13 +98,6 @@ inc_page_cache_write_wait(uint64 latency)
inc_iohist(&MyNeonCounters->file_cache_write_hist, latency);
}
void
inc_query_time(uint64 elapsed)
{
inc_qthist(&MyNeonCounters->query_time_hist, elapsed);
}
/*
* Support functions for the views, neon_backend_perf_counters and
* neon_perf_counters.
@@ -140,11 +112,11 @@ typedef struct
} metric_t;
static int
io_histogram_to_metrics(IOHistogram histogram,
metric_t *metrics,
const char *count,
const char *sum,
const char *bucket)
histogram_to_metrics(IOHistogram histogram,
metric_t *metrics,
const char *count,
const char *sum,
const char *bucket)
{
int i = 0;
uint64 bucket_accum = 0;
@@ -173,44 +145,10 @@ io_histogram_to_metrics(IOHistogram histogram,
return i;
}
static int
qt_histogram_to_metrics(QTHistogram histogram,
metric_t *metrics,
const char *count,
const char *sum,
const char *bucket)
{
int i = 0;
uint64 bucket_accum = 0;
metrics[i].name = count;
metrics[i].is_bucket = false;
metrics[i].value = (double) histogram->elapsed_us_count;
i++;
metrics[i].name = sum;
metrics[i].is_bucket = false;
metrics[i].value = (double) histogram->elapsed_us_sum / 1000000.0;
i++;
for (int bucketno = 0; bucketno < NUM_QT_BUCKETS; bucketno++)
{
uint64 threshold = qt_bucket_thresholds[bucketno];
bucket_accum += histogram->elapsed_us_bucket[bucketno];
metrics[i].name = bucket;
metrics[i].is_bucket = true;
metrics[i].bucket_le = (threshold == UINT64_MAX) ? INFINITY : ((double) threshold) / 1000000.0;
metrics[i].value = (double) bucket_accum;
i++;
}
return i;
}
static metric_t *
neon_perf_counters_to_metrics(neon_per_backend_counters *counters)
{
#define NUM_METRICS ((2 + NUM_IO_WAIT_BUCKETS) * 3 + (2 + NUM_QT_BUCKETS) + 12)
#define NUM_METRICS ((2 + NUM_IO_WAIT_BUCKETS) * 3 + 12)
metric_t *metrics = palloc((NUM_METRICS + 1) * sizeof(metric_t));
int i = 0;
@@ -221,10 +159,10 @@ neon_perf_counters_to_metrics(neon_per_backend_counters *counters)
i++; \
} while (false)
i += io_histogram_to_metrics(&counters->getpage_hist, &metrics[i],
"getpage_wait_seconds_count",
"getpage_wait_seconds_sum",
"getpage_wait_seconds_bucket");
i += histogram_to_metrics(&counters->getpage_hist, &metrics[i],
"getpage_wait_seconds_count",
"getpage_wait_seconds_sum",
"getpage_wait_seconds_bucket");
APPEND_METRIC(getpage_prefetch_requests_total);
APPEND_METRIC(getpage_sync_requests_total);
@@ -240,19 +178,14 @@ neon_perf_counters_to_metrics(neon_per_backend_counters *counters)
APPEND_METRIC(file_cache_hits_total);
i += io_histogram_to_metrics(&counters->file_cache_read_hist, &metrics[i],
"file_cache_read_wait_seconds_count",
"file_cache_read_wait_seconds_sum",
"file_cache_read_wait_seconds_bucket");
i += io_histogram_to_metrics(&counters->file_cache_write_hist, &metrics[i],
"file_cache_write_wait_seconds_count",
"file_cache_write_wait_seconds_sum",
"file_cache_write_wait_seconds_bucket");
i += qt_histogram_to_metrics(&counters->query_time_hist, &metrics[i],
"query_time_seconds_count",
"query_time_seconds_sum",
"query_time_seconds_bucket");
i += histogram_to_metrics(&counters->file_cache_read_hist, &metrics[i],
"file_cache_read_wait_seconds_count",
"file_cache_read_wait_seconds_sum",
"file_cache_read_wait_seconds_bucket");
i += histogram_to_metrics(&counters->file_cache_write_hist, &metrics[i],
"file_cache_write_wait_seconds_count",
"file_cache_write_wait_seconds_sum",
"file_cache_write_wait_seconds_bucket");
Assert(i == NUM_METRICS);
@@ -324,7 +257,7 @@ neon_get_backend_perf_counters(PG_FUNCTION_ARGS)
}
static inline void
io_histogram_merge_into(IOHistogram into, IOHistogram from)
histogram_merge_into(IOHistogram into, IOHistogram from)
{
into->wait_us_count += from->wait_us_count;
into->wait_us_sum += from->wait_us_sum;
@@ -332,15 +265,6 @@ io_histogram_merge_into(IOHistogram into, IOHistogram from)
into->wait_us_bucket[bucketno] += from->wait_us_bucket[bucketno];
}
static inline void
qt_histogram_merge_into(QTHistogram into, QTHistogram from)
{
into->elapsed_us_count += from->elapsed_us_count;
into->elapsed_us_sum += from->elapsed_us_sum;
for (int bucketno = 0; bucketno < NUM_QT_BUCKETS; bucketno++)
into->elapsed_us_bucket[bucketno] += from->elapsed_us_bucket[bucketno];
}
PG_FUNCTION_INFO_V1(neon_get_perf_counters);
Datum
neon_get_perf_counters(PG_FUNCTION_ARGS)
@@ -359,7 +283,7 @@ neon_get_perf_counters(PG_FUNCTION_ARGS)
{
neon_per_backend_counters *counters = &neon_per_backend_counters_shared[procno];
io_histogram_merge_into(&totals.getpage_hist, &counters->getpage_hist);
histogram_merge_into(&totals.getpage_hist, &counters->getpage_hist);
totals.getpage_prefetch_requests_total += counters->getpage_prefetch_requests_total;
totals.getpage_sync_requests_total += counters->getpage_sync_requests_total;
totals.getpage_prefetch_misses_total += counters->getpage_prefetch_misses_total;
@@ -370,13 +294,13 @@ neon_get_perf_counters(PG_FUNCTION_ARGS)
totals.pageserver_open_requests += counters->pageserver_open_requests;
totals.getpage_prefetches_buffered += counters->getpage_prefetches_buffered;
totals.file_cache_hits_total += counters->file_cache_hits_total;
histogram_merge_into(&totals.file_cache_read_hist, &counters->file_cache_read_hist);
histogram_merge_into(&totals.file_cache_write_hist, &counters->file_cache_write_hist);
totals.compute_getpage_stuck_requests_total += counters->compute_getpage_stuck_requests_total;
totals.compute_getpage_max_inflight_stuck_time_ms = Max(
totals.compute_getpage_max_inflight_stuck_time_ms,
counters->compute_getpage_max_inflight_stuck_time_ms);
io_histogram_merge_into(&totals.file_cache_read_hist, &counters->file_cache_read_hist);
io_histogram_merge_into(&totals.file_cache_write_hist, &counters->file_cache_write_hist);
qt_histogram_merge_into(&totals.query_time_hist, &counters->query_time_hist);
}
metrics = neon_perf_counters_to_metrics(&totals);

View File

@@ -36,28 +36,6 @@ typedef struct IOHistogramData
typedef IOHistogramData *IOHistogram;
static const uint64 qt_bucket_thresholds[] = {
2, 3, 6, 10, /* 0 us - 10 us */
20, 30, 60, 100, /* 10 us - 100 us */
200, 300, 600, 1000, /* 100 us - 1 ms */
2000, 3000, 6000, 10000, /* 1 ms - 10 ms */
20000, 30000, 60000, 100000, /* 10 ms - 100 ms */
200000, 300000, 600000, 1000000, /* 100 ms - 1 s */
2000000, 3000000, 6000000, 10000000, /* 1 s - 10 s */
20000000, 30000000, 60000000, 100000000, /* 10 s - 100 s */
UINT64_MAX,
};
#define NUM_QT_BUCKETS (lengthof(qt_bucket_thresholds))
typedef struct QTHistogramData
{
uint64 elapsed_us_count;
uint64 elapsed_us_sum;
uint64 elapsed_us_bucket[NUM_QT_BUCKETS];
} QTHistogramData;
typedef QTHistogramData *QTHistogram;
typedef struct
{
/*
@@ -149,11 +127,6 @@ typedef struct
/* LFC I/O time buckets */
IOHistogramData file_cache_read_hist;
IOHistogramData file_cache_write_hist;
/*
* Histogram of query execution time.
*/
QTHistogramData query_time_hist;
} neon_per_backend_counters;
/* Pointer to the shared memory array of neon_per_backend_counters structs */
@@ -176,7 +149,6 @@ extern neon_per_backend_counters *neon_per_backend_counters_shared;
extern void inc_getpage_wait(uint64 latency);
extern void inc_page_cache_read_wait(uint64 latency);
extern void inc_page_cache_write_wait(uint64 latency);
extern void inc_query_time(uint64 elapsed);
extern Size NeonPerfCountersShmemSize(void);
extern void NeonPerfCountersShmemInit(void);

View File

@@ -5,7 +5,6 @@
#include "funcapi.h"
#include "miscadmin.h"
#include "access/xlog.h"
#include "utils/tuplestore.h"
#include "neon_pgversioncompat.h"
@@ -42,12 +41,5 @@ InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags)
rsinfo->setDesc = stored_tupdesc;
MemoryContextSwitchTo(old_context);
}
TimeLineID GetWALInsertionTimeLine(void)
{
return ThisTimeLineID + 1;
}
#endif

View File

@@ -162,7 +162,6 @@ InitBufferTag(BufferTag *tag, const RelFileNode *rnode,
#if PG_MAJORVERSION_NUM < 15
extern void InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags);
extern TimeLineID GetWALInsertionTimeLine(void);
#endif
#endif /* NEON_PGVERSIONCOMPAT_H */

View File

@@ -69,7 +69,6 @@ struct NeonWALReader
WALSegmentContext segcxt;
WALOpenSegment seg;
int wre_errno;
TimeLineID local_active_tlid;
/* Explains failure to read, static for simplicity. */
char err_msg[NEON_WALREADER_ERR_MSG_LEN];
@@ -107,7 +106,7 @@ struct NeonWALReader
/* palloc and initialize NeonWALReader */
NeonWALReader *
NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_prefix, TimeLineID tlid)
NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_prefix)
{
NeonWALReader *reader;
@@ -119,7 +118,6 @@ NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_
MemoryContextAllocZero(TopMemoryContext, sizeof(NeonWALReader));
reader->available_lsn = available_lsn;
reader->local_active_tlid = tlid;
reader->seg.ws_file = -1;
reader->seg.ws_segno = 0;
reader->seg.ws_tli = 0;
@@ -579,17 +577,6 @@ NeonWALReaderIsRemConnEstablished(NeonWALReader *state)
return state->rem_state == RS_ESTABLISHED;
}
/*
 * Whether remote connection is established. Once this is done, the socket is
 * stable until a successful local read or an error, so the user can update
 * socket events instead of re-adding the socket each time.
*/
TimeLineID
NeonWALReaderLocalActiveTimeLineID(NeonWALReader *state)
{
return state->local_active_tlid;
}
/*
* Returns events user should wait on connection socket or 0 if remote
* connection is not active.

View File

@@ -19,10 +19,9 @@ typedef enum
NEON_WALREAD_ERROR,
} NeonWALReadResult;
extern NeonWALReader *NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_prefix, TimeLineID tlid);
extern NeonWALReader *NeonWALReaderAllocate(int wal_segment_size, XLogRecPtr available_lsn, char *log_prefix);
extern void NeonWALReaderFree(NeonWALReader *state);
extern void NeonWALReaderResetRemote(NeonWALReader *state);
extern TimeLineID NeonWALReaderLocalActiveTimeLineID(NeonWALReader *state);
extern NeonWALReadResult NeonWALRead(NeonWALReader *state, char *buf, XLogRecPtr startptr, Size count, TimeLineID tli);
extern pgsocket NeonWALReaderSocket(NeonWALReader *state);
extern uint32 NeonWALReaderEvents(NeonWALReader *state);

View File

@@ -98,7 +98,6 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
wp = palloc0(sizeof(WalProposer));
wp->config = config;
wp->api = api;
wp->localTimeLineID = config->pgTimeline;
wp->state = WPS_COLLECTING_TERMS;
wp->mconf.generation = INVALID_GENERATION;
wp->mconf.members.len = 0;
@@ -120,10 +119,6 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
{
wp_log(FATAL, "failed to parse neon.safekeepers generation number: %m");
}
if (*endptr != ':')
{
wp_log(FATAL, "failed to parse neon.safekeepers: no colon after generation");
}
/* Skip past : to the first hostname. */
host = endptr + 1;
}
@@ -1385,7 +1380,7 @@ ProcessPropStartPos(WalProposer *wp)
* we must bail out, as clog and other non rel data is inconsistent.
*/
walprop_shared = wp->api.get_shmem_state(wp);
if (!wp->config->syncSafekeepers && !walprop_shared->replica_promote)
if (!wp->config->syncSafekeepers)
{
/*
* Basebackup LSN always points to the beginning of the record (not

View File

@@ -391,7 +391,6 @@ typedef struct WalproposerShmemState
/* last feedback from each shard */
PageserverFeedback shard_ps_feedback[MAX_SHARDS];
int num_shards;
bool replica_promote;
/* aggregated feedback with min LSNs across shards */
PageserverFeedback min_ps_feedback;
@@ -807,9 +806,6 @@ typedef struct WalProposer
/* Safekeepers walproposer is connecting to. */
Safekeeper safekeeper[MAX_SAFEKEEPERS];
/* Current local TimeLineId in use */
TimeLineID localTimeLineID;
/* WAL has been generated up to this point */
XLogRecPtr availableLsn;

View File

@@ -35,7 +35,6 @@
#include "storage/proc.h"
#include "storage/ipc.h"
#include "storage/lwlock.h"
#include "storage/pg_shmem.h"
#include "storage/shmem.h"
#include "storage/spin.h"
#include "tcop/tcopprot.h"
@@ -160,19 +159,12 @@ WalProposerMain(Datum main_arg)
{
WalProposer *wp;
if (*wal_acceptors_list == '\0')
{
wpg_log(WARNING, "Safekeepers list is empty");
return;
}
init_walprop_config(false);
walprop_pg_init_bgworker();
am_walproposer = true;
walprop_pg_load_libpqwalreceiver();
wp = WalProposerCreate(&walprop_config, walprop_pg);
wp->localTimeLineID = GetWALInsertionTimeLine();
wp->last_reconnect_attempt = walprop_pg_get_current_timestamp(wp);
walprop_pg_init_walsender();
@@ -280,30 +272,6 @@ split_safekeepers_list(char *safekeepers_list, char *safekeepers[])
return n_safekeepers;
}
static char *split_off_safekeepers_generation(char *safekeepers_list, uint32 *generation)
{
char *endptr;
if (strncmp(safekeepers_list, "g#", 2) != 0)
{
return safekeepers_list;
}
else
{
errno = 0;
*generation = strtoul(safekeepers_list + 2, &endptr, 10);
if (errno != 0)
{
wp_log(FATAL, "failed to parse neon.safekeepers generation number: %m");
}
if (*endptr != ':')
{
wp_log(FATAL, "failed to parse neon.safekeepers: no colon after generation");
}
return endptr + 1;
}
}
/*
 * Accept two comma-separated strings with lists of safekeeper host:port addresses.
* Split them into arrays and return false if two sets do not match, ignoring the order.
@@ -315,16 +283,6 @@ safekeepers_cmp(char *old, char *new)
char *safekeepers_new[MAX_SAFEKEEPERS];
int len_old = 0;
int len_new = 0;
uint32 gen_old = INVALID_GENERATION;
uint32 gen_new = INVALID_GENERATION;
old = split_off_safekeepers_generation(old, &gen_old);
new = split_off_safekeepers_generation(new, &gen_new);
if (gen_old != gen_new)
{
return false;
}
len_old = split_safekeepers_list(old, safekeepers_old);
len_new = split_safekeepers_list(new, safekeepers_new);
@@ -358,9 +316,6 @@ assign_neon_safekeepers(const char *newval, void *extra)
char *newval_copy;
char *oldval;
if (newval && *newval != '\0' && UsedShmemSegAddr && walprop_shared && RecoveryInProgress())
walprop_shared->replica_promote = true;
if (!am_walproposer)
return;
@@ -551,15 +506,16 @@ BackpressureThrottlingTime(void)
/*
* Register a background worker proposing WAL to wal acceptors.
* We start walproposer bgworker even for replicas in order to support possible replica promotion.
 * When the pg_promote() function is called, the walproposer bgworker registered with BgWorkerStart_RecoveryFinished
 * is launched automatically once promotion completes.
*/
static void
walprop_register_bgworker(void)
{
BackgroundWorker bgw;
/* If no wal acceptors are specified, don't start the background worker. */
if (*wal_acceptors_list == '\0')
return;
memset(&bgw, 0, sizeof(bgw));
bgw.bgw_flags = BGWORKER_SHMEM_ACCESS;
bgw.bgw_start_time = BgWorkerStart_RecoveryFinished;
@@ -1336,7 +1292,9 @@ StartProposerReplication(WalProposer *wp, StartReplicationCmd *cmd)
#if PG_VERSION_NUM < 150000
if (ThisTimeLineID == 0)
ThisTimeLineID = 1;
ereport(ERROR,
(errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
errmsg("IDENTIFY_SYSTEM has not been run before START_REPLICATION")));
#endif
/*
@@ -1550,7 +1508,7 @@ walprop_pg_wal_reader_allocate(Safekeeper *sk)
snprintf(log_prefix, sizeof(log_prefix), WP_LOG_PREFIX "sk %s:%s nwr: ", sk->host, sk->port);
Assert(!sk->xlogreader);
sk->xlogreader = NeonWALReaderAllocate(wal_segment_size, sk->wp->propTermStartLsn, log_prefix, sk->wp->localTimeLineID);
sk->xlogreader = NeonWALReaderAllocate(wal_segment_size, sk->wp->propTermStartLsn, log_prefix);
if (sk->xlogreader == NULL)
wpg_log(FATAL, "failed to allocate xlog reader");
}
@@ -1564,7 +1522,7 @@ walprop_pg_wal_read(Safekeeper *sk, char *buf, XLogRecPtr startptr, Size count,
buf,
startptr,
count,
sk->wp->localTimeLineID);
walprop_pg_get_timeline_id());
if (res == NEON_WALREAD_SUCCESS)
{

View File

@@ -111,7 +111,7 @@ NeonWALPageRead(
readBuf,
targetPagePtr,
count,
NeonWALReaderLocalActiveTimeLineID(wal_reader));
walprop_pg_get_timeline_id());
if (res == NEON_WALREAD_SUCCESS)
{
@@ -202,7 +202,7 @@ NeonOnDemandXLogReaderRoutines(XLogReaderRoutine *xlr)
{
elog(ERROR, "unable to start walsender when basebackupLsn is 0");
}
wal_reader = NeonWALReaderAllocate(wal_segment_size, basebackupLsn, "[walsender] ", 1);
wal_reader = NeonWALReaderAllocate(wal_segment_size, basebackupLsn, "[walsender] ");
}
xlr->page_read = NeonWALPageRead;
xlr->segment_open = NeonWALReadSegmentOpen;

View File

@@ -17,23 +17,35 @@ pub(super) async fn authenticate(
config: &'static AuthenticationConfig,
secret: AuthSecret,
) -> auth::Result<ComputeCredentials> {
let flow = AuthFlow::new(client);
let scram_keys = match secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
debug!("auth endpoint chooses MD5");
return Err(auth::AuthError::MalformedPassword("MD5 not supported"));
return Err(auth::AuthError::bad_auth_method("MD5"));
}
AuthSecret::Scram(secret) => {
debug!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret, ctx);
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
AuthFlow::new(client, auth::Scram(&secret, ctx)).authenticate(),
async {
flow.begin(scram).await.map_err(|error| {
warn!(?error, "error sending scram acknowledgement");
error
})?.authenticate().await.map_err(|error| {
warn!(?error, "error processing scram messages");
error
})
}
)
.await
.inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()))
.map_err(auth::AuthError::user_timeout)?
.inspect_err(|error| warn!(?error, "error processing scram messages"))?;
.map_err(|e| {
warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs());
auth::AuthError::user_timeout(e)
})??;
let client_key = match auth_outcome {
sasl::Outcome::Success(key) => key,

View File

@@ -2,6 +2,7 @@ use std::fmt;
use async_trait::async_trait;
use postgres_client::config::SslMode;
use pq_proto::BeMessage as Be;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span};
@@ -15,9 +16,8 @@ use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::stream::PqStream;
use crate::types::RoleName;
use crate::{auth, compute, waiters};
@@ -154,13 +154,11 @@ async fn authenticate(
// Give user a URL to spawn a new database.
info!(parent: &span, "sending the auth URL to the user");
client.write_message(BeMessage::AuthenticationOk);
client.write_message(BeMessage::ParameterStatus {
name: b"client_encoding",
value: b"UTF8",
});
client.write_message(BeMessage::NoticeResponse(&greeting));
client.flush().await?;
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&Be::CLIENT_ENCODING)?
.write_message(&Be::NoticeResponse(&greeting))
.await?;
// Wait for console response via control plane (see `mgmt`).
info!(parent: &span, "waiting for console's reply...");
@@ -190,7 +188,7 @@ async fn authenticate(
}
}
client.write_message(BeMessage::NoticeResponse("Connecting to database."));
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.

View File

@@ -24,25 +24,23 @@ pub(crate) async fn authenticate_cleartext(
debug!("cleartext auth flow override is enabled, proceeding");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let ep = EndpointIdInt::from(&info.endpoint);
let auth_flow = AuthFlow::new(
client,
auth::CleartextPassword {
let auth_flow = AuthFlow::new(client)
.begin(auth::CleartextPassword {
secret,
endpoint: ep,
pool: config.thread_pool.clone(),
},
);
let auth_outcome = {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// cleartext auth is only allowed to the ws/http protocol.
// If we're here, we already received the password in the first message.
// Scram protocol will be executed on the proxy side.
auth_flow.authenticate().await?
};
})
.await?;
drop(paused);
// cleartext auth is only allowed to the ws/http protocol.
// If we're here, we already received the password in the first message.
// Scram protocol will be executed on the proxy side.
let auth_outcome = auth_flow.authenticate().await?;
let keys = match auth_outcome {
sasl::Outcome::Success(key) => key,
@@ -69,7 +67,9 @@ pub(crate) async fn password_hack_no_authentication(
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let payload = AuthFlow::new(client, auth::PasswordHack)
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
.await?
.get_password()
.await?;

View File

@@ -4,16 +4,19 @@ mod hacks;
pub mod jwt;
pub mod local;
use std::net::IpAddr;
use std::sync::Arc;
pub use console_redirect::ConsoleRedirectBackend;
pub(crate) use console_redirect::ConsoleRedirectError;
use ipnet::{Ipv4Net, Ipv6Net};
use local::LocalBackend;
use postgres_client::config::AuthKeys;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info};
use tracing::{debug, info, warn};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
@@ -21,14 +24,15 @@ use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::{
self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl,
RoleAccessControl,
self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
};
use crate::intern::EndpointIdInt;
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::pqproto::BeMessage;
use crate::metrics::Metrics;
use crate::protocol2::ConnectionInfoExtra;
use crate::proxy::NeonOptions;
use crate::rate_limiter::EndpointRateLimiter;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
use crate::stream::Stream;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{scram, stream};
@@ -131,16 +135,6 @@ impl<'a, T> Backend<'a, T> {
}
}
}
impl<'a, T, E> Backend<'a, Result<T, E>> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
match self {
Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
Self::Local(l) => Ok(Backend::Local(l)),
}
}
}
pub(crate) struct ComputeCredentials {
pub(crate) info: ComputeUserInfo,
@@ -194,6 +188,78 @@ impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
}
}
#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
pub struct MaskedIp(IpAddr);
impl MaskedIp {
fn new(value: IpAddr, prefix: u8) -> Self {
match value {
IpAddr::V4(v4) => Self(IpAddr::V4(
Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
)),
IpAddr::V6(v6) => Self(IpAddr::V6(
Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
)),
}
}
}
// This can't be just per IP because that would limit some PaaS that share IP addresses
pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
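For example, MaskedIp::new applied to 203.0.113.57 with prefix 24 truncates the address to 203.0.113.0, so every client in that /24 shares a single rate-limiter key together with the endpoint ID.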
impl AuthenticationConfig {
pub(crate) fn check_rate_limit(
&self,
ctx: &RequestContext,
secret: AuthSecret,
endpoint: &EndpointId,
is_cleartext: bool,
) -> auth::Result<AuthSecret> {
// we have validated the endpoint exists, so let's intern it.
let endpoint_int = EndpointIdInt::from(endpoint.normalize());
        // only charge the full hash cost for the password hack or websocket flow.
// in other words, if proxy needs to run the hashing
let password_weight = if is_cleartext {
match &secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => 1,
AuthSecret::Scram(s) => s.iterations + 1,
}
} else {
// validating scram takes just 1 hmac_sha_256 operation.
1
};
let limit_not_exceeded = self.rate_limiter.check(
(
endpoint_int,
MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
),
password_weight,
);
if !limit_not_exceeded {
warn!(
enabled = self.rate_limiter_enabled,
"rate limiting authentication"
);
Metrics::get().proxy.requests_auth_rate_limits_total.inc();
Metrics::get()
.proxy
.endpoints_auth_rate_limits
.get_metric()
.measure(endpoint);
if self.rate_limiter_enabled {
return Err(auth::AuthError::too_many_connections());
}
}
Ok(secret)
}
}
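As a concrete reading of the weighting above: a cleartext password check against a SCRAM secret with 4096 iterations charges 4096 + 1 tokens to the (endpoint, masked IP) bucket, because the proxy must run the full key derivation itself, whereas a native SCRAM exchange charges only 1, since validating it takes a single HMAC-SHA-256 operation.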
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
///
@@ -222,27 +288,52 @@ async fn auth_quirks(
debug!("fetching authentication info and allowlists");
let access_controls = api
.get_endpoint_access_control(ctx, &info.endpoint, &info.user)
.await?;
// check allowed list
if config.ip_allowlist_check_enabled {
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
}
access_controls.check(
ctx,
config.ip_allowlist_check_enabled,
config.is_vpc_acccess_proxy,
)?;
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
if config.is_vpc_acccess_proxy {
if access_blocks.vpc_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
let endpoint = EndpointIdInt::from(&info.endpoint);
let rate_limit_config = None;
if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) {
let incoming_vpc_endpoint_id = match ctx.extra() {
None => return Err(AuthError::MissingEndpointName),
Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};
let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?;
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
if !allowed_vpc_endpoint_ids.is_empty()
&& !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
{
return Err(AuthError::vpc_endpoint_id_not_allowed(
incoming_vpc_endpoint_id,
));
}
} else if access_blocks.public_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
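Condensed, the allowlist rule above amounts to the following hypothetical helper; note the TODO in the real code, which means an empty list currently allows every VPC endpoint ID.

// Sketch of the allowlist rule; not an API from this diff.
fn vpc_endpoint_allowed(allowed: &[String], incoming: &str) -> bool {
    allowed.is_empty() || allowed.iter().any(|id| id == incoming)
}

fn main() {
    assert!(vpc_endpoint_allowed(&[], "vpce-123")); // empty list allows all (see TODO)
    assert!(vpc_endpoint_allowed(&["vpce-123".to_string()], "vpce-123"));
    assert!(!vpc_endpoint_allowed(&["vpce-123".to_string()], "vpce-999"));
}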
if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
return Err(AuthError::too_many_connections());
}
let role_access = api
.get_role_access_control(ctx, &info.endpoint, &info.user)
.await?;
let cached_secret = api.get_role_secret(ctx, &info).await?;
let (cached_entry, secret) = cached_secret.take_value();
let secret = if let Some(secret) = role_access.secret {
secret
let secret = if let Some(secret) = secret {
config.check_rate_limit(
ctx,
secret,
&info.endpoint,
unauthenticated_password.is_some() || allow_cleartext,
)?
} else {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
@@ -263,7 +354,13 @@ async fn auth_quirks(
.await
{
Ok(keys) => Ok(keys),
Err(e) => Err(e),
Err(e) => {
if e.is_password_failed() {
// The password could have been changed, so we invalidate the cache.
cached_entry.invalidate();
}
Err(e)
}
}
}
@@ -290,7 +387,7 @@ async fn authenticate_with_secret(
};
// we have authenticated the password
client.write_message(BeMessage::AuthenticationOk);
client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
return Ok(ComputeCredentials { info, keys });
}
@@ -308,71 +405,39 @@ async fn authenticate_with_secret(
classic::authenticate(ctx, info, client, config, secret).await
}
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
/// Get username from the credentials.
pub(crate) fn get_user(&self) -> &str {
match self {
Self::ControlPlane(_, user_info) => &user_info.user,
Self::Local(_) => "local",
}
}
impl ControlPlaneClient {
/// Authenticate the client via the requested backend, possibly using credentials.
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub(crate) async fn authenticate(
self,
&self,
ctx: &RequestContext,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
user_info: ComputeUserInfoMaybeEndpoint,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<Backend<'a, ComputeCredentials>> {
let res = match self {
Self::ControlPlane(api, user_info) => {
debug!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
) -> auth::Result<ComputeCredentials> {
debug!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
let auth_res = auth_quirks(
ctx,
&*api,
user_info.clone(),
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await;
match auth_res {
Ok(credentials) => Ok(Backend::ControlPlane(api, credentials)),
Err(e) => {
// The password could have been changed, so we invalidate the cache.
// We should only invalidate the cache if the TTL might have expired.
if e.is_password_failed() {
#[allow(irrefutable_let_patterns)]
if let ControlPlaneClient::ProxyV1(api) = &*api {
if let Some(ep) = &user_info.endpoint_id {
api.caches
.project_info
.maybe_invalidate_role_secret(ep, &user_info.user);
}
}
}
Err(e)
}
}
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
}
};
let credentials = auth_quirks(
ctx,
self,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
// TODO: replace with some metric
info!("user successfully authenticated");
res
Ok(credentials)
}
}
@@ -380,30 +445,44 @@ impl Backend<'_, ComputeUserInfo> {
pub(crate) async fn get_role_secret(
&self,
ctx: &RequestContext,
) -> Result<RoleAccessControl, GetAuthInfoError> {
) -> Result<CachedRoleSecret, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => {
api.get_role_access_control(ctx, &user_info.endpoint, &user_info.user)
.await
}
Self::Local(_) => Ok(RoleAccessControl { secret: None }),
Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
Self::Local(_) => Ok(Cached::new_uncached(None)),
}
}
pub(crate) async fn get_endpoint_access_control(
pub(crate) async fn get_allowed_ips(
&self,
ctx: &RequestContext,
) -> Result<EndpointAccessControl, GetAuthInfoError> {
) -> Result<CachedAllowedIps, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
}
}
pub(crate) async fn get_allowed_vpc_endpoint_ids(
&self,
ctx: &RequestContext,
) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => {
api.get_endpoint_access_control(ctx, &user_info.endpoint, &user_info.user)
.await
api.get_allowed_vpc_endpoint_ids(ctx, user_info).await
}
Self::Local(_) => Ok(EndpointAccessControl {
allowed_ips: Arc::new(vec![]),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
}),
Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
}
}
pub(crate) async fn get_block_public_or_vpc_access(
&self,
ctx: &RequestContext,
) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => {
api.get_block_public_or_vpc_access(ctx, user_info).await
}
Self::Local(_) => Ok(Cached::new_uncached(AccessBlockerFlags::default())),
}
}
}
@@ -428,11 +507,32 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
}
}
pub struct ControlPlaneWakeCompute<'a> {
pub cplane: &'a ControlPlaneClient,
pub creds: ComputeCredentials,
}
#[async_trait::async_trait]
impl ComputeConnectBackend for ControlPlaneWakeCompute<'_> {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
self.cplane.wake_compute(ctx, &self.creds.info).await
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&self.creds.keys
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::unimplemented, clippy::unwrap_used)]
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use bytes::BytesMut;
use control_plane::AuthSecret;
@@ -442,17 +542,20 @@ mod tests {
use postgres_protocol::message::backend::Message as PgMessage;
use postgres_protocol::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tokio_util::task::TaskTracker;
use super::auth_quirks;
use super::jwt::JwkCache;
use super::{AuthRateLimiter, auth_quirks};
use crate::auth::backend::MaskedIp;
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::{
self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl,
self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps,
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret,
};
use crate::proxy::NeonOptions;
use crate::rate_limiter::EndpointRateLimiter;
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
use crate::scram::ServerSecret;
use crate::scram::threadpool::ThreadPool;
use crate::stream::{PqStream, Stream};
@@ -465,34 +568,46 @@ mod tests {
}
impl control_plane::ControlPlaneApi for Auth {
async fn get_role_access_control(
async fn get_role_secret(
&self,
_ctx: &RequestContext,
_endpoint: &crate::types::EndpointId,
_role: &crate::types::RoleName,
) -> Result<RoleAccessControl, control_plane::errors::GetAuthInfoError> {
Ok(RoleAccessControl {
secret: Some(self.secret.clone()),
})
_user_info: &super::ComputeUserInfo,
) -> Result<CachedRoleSecret, control_plane::errors::GetAuthInfoError> {
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
}
async fn get_endpoint_access_control(
async fn get_allowed_ips(
&self,
_ctx: &RequestContext,
_endpoint: &crate::types::EndpointId,
_role: &crate::types::RoleName,
) -> Result<EndpointAccessControl, control_plane::errors::GetAuthInfoError> {
Ok(EndpointAccessControl {
allowed_ips: Arc::new(self.ips.clone()),
allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()),
flags: self.access_blocker_flags,
})
_user_info: &super::ComputeUserInfo,
) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())))
}
async fn get_allowed_vpc_endpoint_ids(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new(
self.vpc_endpoint_ids.clone(),
)))
}
async fn get_block_public_or_vpc_access(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError> {
Ok(CachedAccessBlockerFlags::new_uncached(
self.access_blocker_flags.clone(),
))
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestContext,
_endpoint: &crate::types::EndpointId,
_endpoint: crate::types::EndpointId,
) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
{
unimplemented!()
@@ -511,6 +626,9 @@ mod tests {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,
is_auth_broker: false,
@@ -527,10 +645,55 @@ mod tests {
}
}
#[test]
fn masked_ip() {
let ip_a = IpAddr::V4([127, 0, 0, 1].into());
let ip_b = IpAddr::V4([127, 0, 0, 2].into());
let ip_c = IpAddr::V4([192, 168, 1, 101].into());
let ip_d = IpAddr::V4([192, 168, 1, 102].into());
let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());
assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));
assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
}
#[test]
fn test_default_auth_rate_limit_set() {
// these values used to exceed u32::MAX
assert_eq!(
RateBucketInfo::DEFAULT_AUTH_SET,
[
RateBucketInfo {
interval: Duration::from_secs(1),
max_rpi: 1000 * 4096,
},
RateBucketInfo {
interval: Duration::from_secs(60),
max_rpi: 600 * 4096 * 60,
},
RateBucketInfo {
interval: Duration::from_secs(600),
max_rpi: 300 * 4096 * 600,
}
]
);
for x in RateBucketInfo::DEFAULT_AUTH_SET {
let y = x.to_string().parse().unwrap();
assert_eq!(x, y);
}
}
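As a quick sanity check on the constants asserted above (assuming max_rpi is a u32, per the overflow comment): the largest bucket allows 300 * 4096 * 600 = 737,280,000 requests per 600-second interval, and the middle one 600 * 4096 * 60 = 147,456,000, both comfortably below u32::MAX = 4,294,967,295.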
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let ctx = RequestContext::test();
let api = Auth {
@@ -612,7 +775,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let ctx = RequestContext::test();
let api = Auth {
@@ -666,7 +829,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let ctx = RequestContext::test();
let api = Auth {


@@ -5,6 +5,7 @@ use std::net::IpAddr;
use std::str::FromStr;
use itertools::Itertools;
use pq_proto::StartupMessageParams;
use thiserror::Error;
use tracing::{debug, warn};
@@ -12,7 +13,6 @@ use crate::auth::password_hack::parse_endpoint_param;
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI};
use crate::types::{EndpointId, RoleName};


@@ -1,8 +1,10 @@
//! Main authentication flow.
use std::io;
use std::sync::Arc;
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -11,26 +13,35 @@ use super::{AuthError, PasswordHackPayload};
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
use crate::sasl;
use crate::scram::threadpool::ThreadPool;
use crate::scram::{self};
use crate::stream::{PqStream, Stream};
use crate::tls::TlsServerEndPoint;
/// Every authentication selector is supposed to implement this trait.
pub(crate) trait AuthMethod {
/// Any authentication selector should provide an initial backend message
/// containing the auth method name and parameters, e.g. an MD5 salt.
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
}
/// Initial state of [`AuthFlow`].
pub(crate) struct Begin;
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
pub(crate) struct Scram<'a>(
pub(crate) &'a scram::ServerSecret,
pub(crate) &'a RequestContext,
);
impl Scram<'_> {
impl AuthMethod for Scram<'_> {
#[inline(always)]
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
if channel_binding {
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
} else {
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
scram::METHODS_WITHOUT_PLUS,
))
}
@@ -41,6 +52,13 @@ impl Scram<'_> {
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
pub(crate) struct PasswordHack;
impl AuthMethod for PasswordHack {
#[inline(always)]
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
pub(crate) struct CleartextPassword {
@@ -49,37 +67,53 @@ pub(crate) struct CleartextPassword {
pub(crate) secret: AuthSecret,
}
impl AuthMethod for CleartextPassword {
#[inline(always)]
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub(crate) struct AuthFlow<'a, S, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut PqStream<Stream<S>>,
/// State might contain ancillary data.
/// State might contain ancillary data (see [`Self::begin`]).
state: State,
tls_server_end_point: TlsServerEndPoint,
}
/// Initial state of the stream wrapper.
impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> {
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
/// Create a new wrapper for client authentication.
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>, method: M) -> Self {
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
let tls_server_end_point = stream.get_ref().tls_server_end_point();
Self {
stream,
state: method,
state: Begin,
tls_server_end_point,
}
}
/// Move to the next step by sending auth method's name & params to client.
pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
self.stream
.write_message(&method.first_message(self.tls_server_end_point.supported()))
.await?;
Ok(AuthFlow {
stream: self.stream,
state: method,
tls_server_end_point: self.tls_server_end_point,
})
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
self.stream
.write_message(BeMessage::AuthenticationCleartextPassword);
self.stream.flush().await?;
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
@@ -99,10 +133,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
self.stream
.write_message(BeMessage::AuthenticationCleartextPassword);
self.stream.flush().await?;
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
@@ -117,7 +147,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
.await?;
if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message(BeMessage::AuthenticationOk);
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
}
Ok(outcome)
@@ -129,36 +159,42 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
let Scram(secret, ctx) = self.state;
let channel_binding = self.tls_server_end_point;
// send sasl message.
{
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let sasl = self.state.first_message(channel_binding.supported());
self.stream.write_message(sasl);
self.stream.flush().await?;
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg)
.ok_or(AuthError::MalformedPassword("bad sasl message"))?;
// Currently, the only supported SASL method is SCRAM.
if !scram::METHODS.contains(&sasl.method) {
return Err(super::AuthError::bad_auth_method(sasl.method));
}
// complete sasl handshake.
sasl::authenticate(ctx, self.stream, |method| {
// Currently, the only supported SASL method is SCRAM.
match method {
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => {
ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus);
}
method => return Err(sasl::Error::BadAuthMethod(method.into())),
}
match sasl.method {
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
_ => {}
}
// TODO: make this a metric instead
info!("client chooses {}", method);
// TODO: make this a metric instead
info!("client chooses {}", sasl.method);
Ok(scram::Exchange::new(secret, rand::random, channel_binding))
})
.await
.map_err(AuthError::Sasl)
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
.authenticate(scram::Exchange::new(
secret,
rand::random,
self.tls_server_end_point,
))
.await?;
if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
}
Ok(outcome)
}
}


@@ -1,7 +1,7 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::Backend;
pub use backend::{Backend, ControlPlaneWakeCompute};
mod credentials;
pub(crate) use credentials::{


@@ -32,7 +32,9 @@ use crate::ext::TaskExt;
use crate::http::health_server::AppMetrics;
use crate::intern::RoleNameInt;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::rate_limiter::{
BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo,
};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::{self, GlobalConnPoolOptions};
@@ -67,6 +69,15 @@ struct LocalProxyCliArgs {
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
user_rps_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
/// Authentication rate limiter max number of hashes per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
auth_rate_limit: Vec<RateBucketInfo>,
/// The IP subnet prefix length used to decide whether two IP addresses are treated as the same client.
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// Whether to retry the connection to the compute node
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
connect_to_compute_retry: String,
@@ -271,6 +282,9 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
rate_limiter_enabled: false,
rate_limiter: BucketRateLimiter::new(vec![]),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,
is_auth_broker: false,


@@ -4,9 +4,8 @@
//! This allows connecting to pods/services running in the same Kubernetes cluster from
//! the outside. Similar to an ingress controller for HTTPS.
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
use std::{net::SocketAddr, sync::Arc};
use anyhow::{Context, anyhow, bail, ensure};
use clap::Arg;
@@ -18,20 +17,18 @@ use rustls::pki_types::{DnsName, PrivateKeyDer};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio_rustls::TlsConnector;
use tokio_rustls::server::TlsStream;
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, error, info};
use utils::project_git_version;
use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
use crate::proxy::{
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
};
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled};
use crate::stream::{PqStream, Stream};
use crate::tls::TlsServerEndPoint;
project_git_version!(GIT_VERSION);
@@ -88,7 +85,7 @@ pub async fn run() -> anyhow::Result<()> {
.parse()?;
// Configure TLS
let tls_config = match (
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
args.get_one::<String>("tls-key"),
args.get_one::<String>("tls-cert"),
) {
@@ -121,6 +118,7 @@ pub async fn run() -> anyhow::Result<()> {
dest.clone(),
tls_config.clone(),
None,
tls_server_end_point,
proxy_listener,
cancellation_token.clone(),
))
@@ -130,6 +128,7 @@ pub async fn run() -> anyhow::Result<()> {
dest,
tls_config,
Some(compute_tls_config),
tls_server_end_point,
proxy_listener_compute_tls,
cancellation_token.clone(),
))
@@ -156,7 +155,7 @@ pub async fn run() -> anyhow::Result<()> {
pub(super) fn parse_tls(
key_path: &Path,
cert_path: &Path,
) -> anyhow::Result<Arc<rustls::ServerConfig>> {
) -> anyhow::Result<(Arc<rustls::ServerConfig>, TlsServerEndPoint)> {
let key = {
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
@@ -189,6 +188,10 @@ pub(super) fn parse_tls(
})?
};
// needed for channel bindings
let first_cert = cert_chain.first().context("missing certificate")?;
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
let tls_config =
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
@@ -197,13 +200,14 @@ pub(super) fn parse_tls(
.with_single_cert(cert_chain, key)?
.into();
Ok(tls_config)
Ok((tls_config, tls_server_end_point))
}
pub(super) async fn task_main(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
tls_server_end_point: TlsServerEndPoint,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -223,7 +227,8 @@ pub(super) async fn task_main(
let dest_suffix = Arc::clone(&dest_suffix);
let compute_tls_config = compute_tls_config.clone();
connections.spawn(
let tracker = connections.token();
tokio::spawn(
async move {
socket
.set_nodelay(true)
@@ -239,7 +244,16 @@ pub(super) async fn task_main(
crate::metrics::Protocol::SniRouter,
"sni",
);
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
handle_client(
ctx,
dest_suffix,
tls_config,
compute_tls_config,
tls_server_end_point,
socket,
tracker,
)
.await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -258,26 +272,59 @@ pub(super) async fn task_main(
Ok(())
}
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
raw_stream: S,
tracker: TaskTrackerToken,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<TlsStream<S>> {
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?;
match msg {
FeStartupPacket::SslRequest { direct: None } => {
let raw = stream.accept_tls().await?;
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<(Stream<S>, TaskTrackerToken)> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream), tracker);
Ok(raw
.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?)
let msg = stream.read_startup_packet().await?;
use pq_proto::FeStartupPacket::SslRequest;
match msg {
SslRequest { direct: false } => {
stream
.write_message(&pq_proto::BeMessage::EncryptionResponse(true))
.await?;
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf, tracker) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empty.
// However, you could imagine pipelining of postgres
// SSLRequest + TLS ClientHello in one hunk similar to
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
Ok((
Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
},
tracker,
))
}
unexpected => {
info!(
?unexpected,
"unexpected startup packet, rejecting connection"
);
Err(stream.throw_error(TlsRequired, None).await)?
stream
.throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None)
.await?
}
}
}
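For orientation, here is a small sketch of the wire exchange this handshake implements. It is the standard PostgreSQL SSLRequest dance rather than code from this diff: the client sends an 8-byte packet and the server answers with a single byte before any TLS ClientHello may follow.

// Build the 8-byte SSLRequest packet (standard PostgreSQL protocol).
fn ssl_request_packet() -> [u8; 8] {
    let mut buf = [0u8; 8];
    buf[..4].copy_from_slice(&8u32.to_be_bytes()); // message length, including itself
    buf[4..].copy_from_slice(&80877103u32.to_be_bytes()); // request code: 1234 << 16 | 5679
    buf
}

fn main() {
    assert_eq!(ssl_request_packet(), [0, 0, 0, 8, 4, 210, 22, 47]);
    // The server then replies with b'S' (proceed with TLS, i.e. EncryptionResponse(true))
    // or b'N' (TLS refused).
}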
@@ -287,18 +334,17 @@ async fn handle_client(
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
tracker: TaskTrackerToken,
) -> anyhow::Result<()> {
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
let (mut tls_stream, _tracker) =
ssl_handshake(&ctx, stream, tracker, tls_config, tls_server_end_point).await?;
// Cut off the first part of the SNI domain.
// We receive the required destination details in the format of
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
let sni = tls_stream
.get_ref()
.1
.server_name()
.ok_or(anyhow!("SNI missing"))?;
let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
let dest: Vec<&str> = sni
.split_once('.')
.context("invalid SNI")?

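A hypothetical condensed version of that SNI convention follows; the helper name and example domain are illustrative, not taken from this diff.

// Parse "{service}--{namespace}--{port}.router-domain" into its parts.
fn parse_destination(sni: &str) -> Option<(&str, &str, u16)> {
    let (label, _router_domain) = sni.split_once('.')?; // drop the non-SNI suffix
    let mut parts = label.split("--");
    let service = parts.next()?;
    let namespace = parts.next()?;
    let port: u16 = parts.next()?.parse().ok()?;
    Some((service, namespace, port))
}

fn main() {
    assert_eq!(
        parse_destination("mycompute--prod--5432.router.example.com"),
        Some(("mycompute", "prod", 5432))
    );
}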

@@ -20,7 +20,7 @@ use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
use crate::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use crate::cancellation::{CancellationHandler, handle_cancel_messages};
use crate::config::{
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
@@ -29,7 +29,9 @@ use crate::config::{
use crate::context::parquet::ParquetUploadArgs;
use crate::http::health_server::AppMetrics;
use crate::metrics::Metrics;
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
use crate::rate_limiter::{
EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter,
};
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::redis::kv_ops::RedisKVClient;
use crate::redis::{elasticache, notifications};
@@ -152,6 +154,15 @@ struct ProxyCliArgs {
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
/// Authentication rate limiter max number of hashes per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
auth_rate_limit: Vec<RateBucketInfo>,
/// The IP subnet prefix length used to decide whether two IP addresses are treated as the same client.
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
@@ -221,7 +232,8 @@ struct ProxyCliArgs {
is_private_access_proxy: bool,
/// Configure whether all incoming requests have a Proxy Protocol V2 packet.
#[clap(value_enum, long, default_value_t = ProxyProtocolV2::Rejected)]
// TODO(conradludgate): switch default to rejected or required once we've updated all deployments
#[clap(value_enum, long, default_value_t = ProxyProtocolV2::Supported)]
proxy_protocol_v2: ProxyProtocolV2,
/// Time the proxy waits for the webauth session to be confirmed by the control plane.
@@ -398,9 +410,22 @@ pub async fn run() -> anyhow::Result<()> {
Some(tx_cancel),
));
// bit of a hack - find the min rps and max rps supported and turn it into
// leaky bucket config instead
let max = args
.endpoint_rps_limit
.iter()
.map(|x| x.rps())
.max_by(f64::total_cmp)
.unwrap_or(EndpointRateLimiter::DEFAULT.max);
let rps = args
.endpoint_rps_limit
.iter()
.map(|x| x.rps())
.min_by(f64::total_cmp)
.unwrap_or(EndpointRateLimiter::DEFAULT.rps);
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
.unwrap_or(EndpointRateLimiter::DEFAULT),
LeakyBucketConfig { rps, max },
64,
));
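A sketch of that min/max folding, assuming each RateBucketInfo exposes its requests per second via rps() as above; the local LeakyBucketConfig struct here is a stand-in for the real one.

#[derive(Debug, PartialEq)]
struct LeakyBucketConfig {
    rps: f64,
    max: f64,
}

// The slowest bucket becomes the sustained rate, the fastest the burst ceiling.
fn fold_buckets(rps_limits: &[f64], default: LeakyBucketConfig) -> LeakyBucketConfig {
    let rps = rps_limits.iter().copied().min_by(f64::total_cmp).unwrap_or(default.rps);
    let max = rps_limits.iter().copied().max_by(f64::total_cmp).unwrap_or(default.max);
    LeakyBucketConfig { rps, max }
}

fn main() {
    let default = LeakyBucketConfig { rps: 100.0, max: 1000.0 };
    let cfg = fold_buckets(&[300.0, 100.0, 50.0], default);
    assert_eq!(cfg, LeakyBucketConfig { rps: 50.0, max: 300.0 });
}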
@@ -451,7 +476,8 @@ pub async fn run() -> anyhow::Result<()> {
let key_path = args.tls_key.expect("already asserted it is set");
let cert_path = args.tls_cert.expect("already asserted it is set");
let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
let (tls_config, tls_server_end_point) =
super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
let dest = Arc::new(dest);
@@ -459,6 +485,7 @@ pub async fn run() -> anyhow::Result<()> {
dest.clone(),
tls_config.clone(),
None,
tls_server_end_point,
listen,
cancellation_token.clone(),
));
@@ -467,6 +494,7 @@ pub async fn run() -> anyhow::Result<()> {
dest,
tls_config,
Some(config.connect_to_compute.tls.clone()),
tls_server_end_point,
listen_tls,
cancellation_token.clone(),
));
@@ -653,6 +681,9 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
jwks_cache: JwkCache::default(),
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_vpc_acccess_proxy: args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,

Some files were not shown because too many files have changed in this diff.