Merge branch 'main' into ruslan/subzero-integration

This commit is contained in:
Ruslan Talpa
2025-07-01 13:46:57 +03:00
105 changed files with 3256 additions and 1410 deletions

View File

@@ -4,6 +4,7 @@
!Cargo.lock
!Cargo.toml
!Makefile
!postgres.mk
!rust-toolchain.toml
!scripts/ninstall.sh
!docker-compose/run-tests.sh

View File

@@ -94,11 +94,6 @@ jobs:
run: |
make "neon-pg-ext-${{ matrix.postgres-version }}" -j$(sysctl -n hw.ncpu)
- name: Get postgres headers ${{ matrix.postgres-version }}
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
make postgres-headers-${{ matrix.postgres-version }} -j$(sysctl -n hw.ncpu)
- name: Upload "pg_install/${{ matrix.postgres-version }}" artifact
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
@@ -140,6 +135,12 @@ jobs:
name: pg_install--v17
path: pg_install/v17
# `actions/download-artifact` doesn't preserve permissions:
# https://github.com/actions/download-artifact?tab=readme-ov-file#permission-loss
- name: Make pg_install/v*/bin/* executable
run: |
chmod +x pg_install/v*/bin/*
- name: Cache walproposer-lib
id: cache_walproposer_lib
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
@@ -167,7 +168,7 @@ jobs:
- name: Build walproposer-lib (only for v17)
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run:
make walproposer-lib -j$(sysctl -n hw.ncpu)
make walproposer-lib -j$(sysctl -n hw.ncpu) PG_INSTALL_CACHED=1
- name: Upload "build/walproposer-lib" artifact
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2

View File

@@ -69,7 +69,7 @@ jobs:
submodules: true
- name: Check for file changes
uses: step-security/paths-filter@v3
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
id: files-changed
with:
token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -153,7 +153,7 @@ jobs:
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}"
- name: Benchmark database maintenance
if: ${{ matrix.test_maintenance == 'true' }}
if: ${{ matrix.test_maintenance }}
uses: ./.github/actions/run-python-test-set
with:
build_type: ${{ env.BUILD_TYPE }}

View File

@@ -53,7 +53,7 @@ jobs:
submodules: true
- name: Check for Postgres changes
uses: step-security/paths-filter@v3
uses: dorny/paths-filter@1441771bbfdd59dcd748680ee64ebd8faab1a242 #v3
id: files_changed
with:
token: ${{ github.token }}

View File

@@ -34,7 +34,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: step-security/changed-files@3dbe17c78367e7d60f00d78ae6781a35be47b4a1 # v45.0.1
- uses: tj-actions/changed-files@ed68ef82c095e0d48ec87eccea555d944a631a4c # v46.0.5
id: python-src
with:
files: |
@@ -45,7 +45,7 @@ jobs:
poetry.lock
pyproject.toml
- uses: step-security/changed-files@3dbe17c78367e7d60f00d78ae6781a35be47b4a1 # v45.0.1
- uses: tj-actions/changed-files@ed68ef82c095e0d48ec87eccea555d944a631a4c # v46.0.5
id: rust-src
with:
files: |

View File

@@ -60,22 +60,23 @@ jobs:
} >> "$GITHUB_ENV"
- name: Run proxy-bench
run: ./${PROXY_BENCH_PATH}/run.sh
run: ${PROXY_BENCH_PATH}/run.sh
- name: Ingest Bench Results # neon repo script
if: success()
if: always()
run: |
mkdir -p $TEST_OUTPUT
python $NEON_DIR/scripts/proxy_bench_results_ingest.py --out $TEST_OUTPUT
- name: Push Metrics to Proxy perf database
if: success()
if: always()
env:
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PROXY_TEST_RESULT_CONNSTR }}"
REPORT_FROM: $TEST_OUTPUT
run: $NEON_DIR/scripts/generate_and_push_perf_report.sh
- name: Docker cleanup
if: always()
run: docker compose down
- name: Notify Failure

21
Cargo.lock generated
View File

@@ -1285,6 +1285,7 @@ dependencies = [
"remote_storage",
"serde",
"serde_json",
"url",
"utils",
]
@@ -1322,6 +1323,7 @@ dependencies = [
"opentelemetry",
"opentelemetry_sdk",
"p256 0.13.2",
"pageserver_page_api",
"postgres",
"postgres_initdb",
"postgres_versioninfo",
@@ -1341,6 +1343,7 @@ dependencies = [
"tokio-postgres",
"tokio-stream",
"tokio-util",
"tonic 0.13.1",
"tower 0.5.2",
"tower-http",
"tower-otel",
@@ -4485,6 +4488,7 @@ dependencies = [
"postgres_backend",
"postgres_ffi_types",
"postgres_versioninfo",
"posthog_client_lite",
"rand 0.8.5",
"remote_storage",
"reqwest",
@@ -4495,6 +4499,7 @@ dependencies = [
"strum",
"strum_macros",
"thiserror 1.0.69",
"tracing",
"tracing-utils",
"utils",
]
@@ -4551,12 +4556,14 @@ dependencies = [
"bytes",
"futures",
"pageserver_api",
"postgres_ffi",
"postgres_ffi_types",
"prost 0.13.5",
"prost-types 0.13.5",
"strum",
"strum_macros",
"thiserror 1.0.69",
"tokio",
"tokio-util",
"tonic 0.13.1",
"tonic-build",
"utils",
@@ -5242,7 +5249,7 @@ dependencies = [
"petgraph",
"prettyplease",
"prost 0.13.5",
"prost-types 0.13.3",
"prost-types 0.13.5",
"regex",
"syn 2.0.100",
"tempfile",
@@ -5285,9 +5292,9 @@ dependencies = [
[[package]]
name = "prost-types"
version = "0.13.3"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4759aa0d3a6232fb8dbdb97b61de2c20047c68aca932c7ed76da9d788508d670"
checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16"
dependencies = [
"prost 0.13.5",
]
@@ -6919,6 +6926,7 @@ dependencies = [
"chrono",
"clap",
"clashmap",
"compute_api",
"control_plane",
"cron",
"diesel",
@@ -7772,7 +7780,7 @@ dependencies = [
"prettyplease",
"proc-macro2",
"prost-build 0.13.3",
"prost-types 0.13.3",
"prost-types 0.13.5",
"quote",
"syn 2.0.100",
]
@@ -7784,7 +7792,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9687bd5bfeafebdded2356950f278bba8226f0b32109537c4253406e09aafe1"
dependencies = [
"prost 0.13.5",
"prost-types 0.13.3",
"prost-types 0.13.5",
"tokio",
"tokio-stream",
"tonic 0.13.1",
@@ -8813,7 +8821,6 @@ dependencies = [
"num-iter",
"num-rational",
"num-traits",
"once_cell",
"p256 0.13.2",
"parquet",
"prettyplease",

View File

@@ -152,6 +152,7 @@ pprof = { version = "0.14", features = ["criterion", "flamegraph", "frame-pointe
procfs = "0.16"
prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency
prost = "0.13.5"
prost-types = "0.13.5"
rand = "0.8"
redis = { version = "0.29.2", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
@@ -199,7 +200,7 @@ tokio-postgres-rustls = "0.12.0"
tokio-rustls = { version = "0.26.0", default-features = false, features = ["tls12", "ring"]}
tokio-stream = "0.1"
tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
tokio-util = { version = "0.7.10", features = ["io", "io-util", "rt"] }
toml = "0.8"
toml_edit = "0.22"
tonic = { version = "0.13.1", default-features = false, features = ["channel", "codegen", "gzip", "prost", "router", "server", "tls-ring", "tls-native-roots", "zstd"] }

View File

@@ -40,6 +40,7 @@ COPY --chown=nonroot vendor/postgres-v16 vendor/postgres-v16
COPY --chown=nonroot vendor/postgres-v17 vendor/postgres-v17
COPY --chown=nonroot pgxn pgxn
COPY --chown=nonroot Makefile Makefile
COPY --chown=nonroot postgres.mk postgres.mk
COPY --chown=nonroot scripts/ninstall.sh scripts/ninstall.sh
ENV BUILD_TYPE=release

129
Makefile
View File

@@ -4,11 +4,14 @@ ROOT_PROJECT_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
# managers.
POSTGRES_INSTALL_DIR ?= $(ROOT_PROJECT_DIR)/pg_install/
# Supported PostgreSQL versions
POSTGRES_VERSIONS = v17 v16 v15 v14
# CARGO_BUILD_FLAGS: Extra flags to pass to `cargo build`. `--locked`
# and `--features testing` are popular examples.
#
# CARGO_PROFILE: You can also set to override the cargo profile to
# use. By default, it is derived from BUILD_TYPE.
# CARGO_PROFILE: Set to override the cargo profile to use. By default,
# it is derived from BUILD_TYPE.
# All intermediate build artifacts are stored here.
BUILD_DIR := build
@@ -95,91 +98,24 @@ CACHEDIR_TAG_CONTENTS := "Signature: 8a477f597d28d172789f06886806bc55"
# Top level Makefile to build Neon and PostgreSQL
#
.PHONY: all
all: neon postgres neon-pg-ext
all: neon postgres-install neon-pg-ext
### Neon Rust bits
#
# The 'postgres_ffi' depends on the Postgres headers.
.PHONY: neon
neon: postgres-headers walproposer-lib cargo-target-dir
neon: postgres-headers-install walproposer-lib cargo-target-dir
+@echo "Compiling Neon"
$(CARGO_CMD_PREFIX) cargo build $(CARGO_BUILD_FLAGS) $(CARGO_PROFILE)
.PHONY: cargo-target-dir
cargo-target-dir:
# https://github.com/rust-lang/cargo/issues/14281
mkdir -p target
test -e target/CACHEDIR.TAG || echo "$(CACHEDIR_TAG_CONTENTS)" > target/CACHEDIR.TAG
### PostgreSQL parts
# Some rules are duplicated for Postgres v14 and 15. We may want to refactor
# to avoid the duplication in the future, but it's tolerable for now.
#
$(BUILD_DIR)/%/config.status:
mkdir -p $(BUILD_DIR)
test -e $(BUILD_DIR)/CACHEDIR.TAG || echo "$(CACHEDIR_TAG_CONTENTS)" > $(BUILD_DIR)/CACHEDIR.TAG
+@echo "Configuring Postgres $* build"
@test -s $(ROOT_PROJECT_DIR)/vendor/postgres-$*/configure || { \
echo "\nPostgres submodule not found in $(ROOT_PROJECT_DIR)/vendor/postgres-$*/, execute "; \
echo "'git submodule update --init --recursive --depth 2 --progress .' in project root.\n"; \
exit 1; }
mkdir -p $(BUILD_DIR)/$*
VERSION=$*; \
EXTRA_VERSION=$$(cd $(ROOT_PROJECT_DIR)/vendor/postgres-$$VERSION && git rev-parse HEAD); \
(cd $(BUILD_DIR)/$$VERSION && \
env PATH="$(EXTRA_PATH_OVERRIDES):$$PATH" $(ROOT_PROJECT_DIR)/vendor/postgres-$$VERSION/configure \
CFLAGS='$(PG_CFLAGS)' LDFLAGS='$(PG_LDFLAGS)' \
$(PG_CONFIGURE_OPTS) --with-extra-version=" ($$EXTRA_VERSION)" \
--prefix=$(abspath $(POSTGRES_INSTALL_DIR))/$$VERSION > configure.log)
# nicer alias to run 'configure'
# Note: I've been unable to use templates for this part of our configuration.
# I'm not sure why it wouldn't work, but this is the only place (apart from
# the "build-all-versions" entry points) where direct mention of PostgreSQL
# versions is used.
.PHONY: postgres-configure-v17
postgres-configure-v17: $(BUILD_DIR)/v17/config.status
.PHONY: postgres-configure-v16
postgres-configure-v16: $(BUILD_DIR)/v16/config.status
.PHONY: postgres-configure-v15
postgres-configure-v15: $(BUILD_DIR)/v15/config.status
.PHONY: postgres-configure-v14
postgres-configure-v14: $(BUILD_DIR)/v14/config.status
# Install the PostgreSQL header files into $(POSTGRES_INSTALL_DIR)/<version>/include
.PHONY: postgres-headers-%
postgres-headers-%: postgres-configure-%
+@echo "Installing PostgreSQL $* headers"
$(MAKE) -C $(BUILD_DIR)/$*/src/include MAKELEVEL=0 install
# Compile and install PostgreSQL
.PHONY: postgres-%
postgres-%: postgres-configure-% \
postgres-headers-% # to prevent `make install` conflicts with neon's `postgres-headers`
+@echo "Compiling PostgreSQL $*"
$(MAKE) -C $(BUILD_DIR)/$* MAKELEVEL=0 install
+@echo "Compiling pg_prewarm $*"
$(MAKE) -C $(BUILD_DIR)/$*/contrib/pg_prewarm install
+@echo "Compiling pg_buffercache $*"
$(MAKE) -C $(BUILD_DIR)/$*/contrib/pg_buffercache install
+@echo "Compiling pg_visibility $*"
$(MAKE) -C $(BUILD_DIR)/$*/contrib/pg_visibility install
+@echo "Compiling pageinspect $*"
$(MAKE) -C $(BUILD_DIR)/$*/contrib/pageinspect install
+@echo "Compiling pg_trgm $*"
$(MAKE) -C $(BUILD_DIR)/$*/contrib/pg_trgm install
+@echo "Compiling amcheck $*"
$(MAKE) -C $(BUILD_DIR)/$*/contrib/amcheck install
+@echo "Compiling test_decoding $*"
$(MAKE) -C $(BUILD_DIR)/$*/contrib/test_decoding install
.PHONY: postgres-check-%
postgres-check-%: postgres-%
$(MAKE) -C $(BUILD_DIR)/$* MAKELEVEL=0 check
.PHONY: neon-pg-ext-%
neon-pg-ext-%: postgres-%
neon-pg-ext-%: postgres-install-%
+@echo "Compiling neon-specific Postgres extensions for $*"
mkdir -p $(BUILD_DIR)/pgxn-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
@@ -218,39 +154,14 @@ ifeq ($(UNAME_S),Linux)
pg_crc32c.o
endif
# Shorthand to call neon-pg-ext-% target for all Postgres versions
.PHONY: neon-pg-ext
neon-pg-ext: \
neon-pg-ext-v14 \
neon-pg-ext-v15 \
neon-pg-ext-v16 \
neon-pg-ext-v17
# shorthand to build all Postgres versions
.PHONY: postgres
postgres: \
postgres-v14 \
postgres-v15 \
postgres-v16 \
postgres-v17
.PHONY: postgres-headers
postgres-headers: \
postgres-headers-v14 \
postgres-headers-v15 \
postgres-headers-v16 \
postgres-headers-v17
.PHONY: postgres-check
postgres-check: \
postgres-check-v14 \
postgres-check-v15 \
postgres-check-v16 \
postgres-check-v17
neon-pg-ext: $(foreach pg_version,$(POSTGRES_VERSIONS),neon-pg-ext-$(pg_version))
# This removes everything
.PHONY: distclean
distclean:
$(RM) -r $(POSTGRES_INSTALL_DIR)
$(RM) -r $(POSTGRES_INSTALL_DIR) $(BUILD_DIR)
$(CARGO_CMD_PREFIX) cargo clean
.PHONY: fmt
@@ -298,3 +209,19 @@ neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17
.PHONY: setup-pre-commit-hook
setup-pre-commit-hook:
ln -s -f $(ROOT_PROJECT_DIR)/pre-commit.py .git/hooks/pre-commit
# Targets for building PostgreSQL are defined in postgres.mk.
#
# But if the caller has indicated that PostgreSQL is already
# installed, by setting the PG_INSTALL_CACHED variable, skip it.
ifdef PG_INSTALL_CACHED
postgres-install: skip-install
$(foreach pg_version,$(POSTGRES_VERSIONS),postgres-install-$(pg_version)): skip-install
postgres-headers-install:
+@echo "Skipping installation of PostgreSQL headers because PG_INSTALL_CACHED is set"
skip-install:
+@echo "Skipping PostgreSQL installation because PG_INSTALL_CACHED is set"
else
include postgres.mk
endif

View File

@@ -165,6 +165,7 @@ RUN curl -fsSL \
&& rm sql_exporter.tar.gz
# protobuf-compiler (protoc)
# Keep the version the same as in compute/compute-node.Dockerfile
ENV PROTOC_VERSION=25.1
RUN curl -fsSL "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-$(uname -m | sed 's/aarch64/aarch_64/g').zip" -o "protoc.zip" \
&& unzip -q protoc.zip -d protoc \
@@ -179,7 +180,7 @@ RUN curl -sL "https://github.com/peak/s5cmd/releases/download/v${S5CMD_VERSION}/
&& mv s5cmd /usr/local/bin/s5cmd
# LLVM
ENV LLVM_VERSION=19
ENV LLVM_VERSION=20
RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \
&& echo "deb http://apt.llvm.org/${DEBIAN_VERSION}/ llvm-toolchain-${DEBIAN_VERSION}-${LLVM_VERSION} main" > /etc/apt/sources.list.d/llvm.stable.list \
&& apt update \
@@ -292,7 +293,7 @@ WORKDIR /home/nonroot
# Rust
# Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`)
ENV RUSTC_VERSION=1.87.0
ENV RUSTC_VERSION=1.88.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
ARG RUSTFILT_VERSION=0.2.1

View File

@@ -115,6 +115,9 @@ ARG EXTENSIONS=all
FROM $BASE_IMAGE_SHA AS build-deps
ARG DEBIAN_VERSION
# Keep in sync with build-tools.Dockerfile
ENV PROTOC_VERSION=25.1
# Use strict mode for bash to catch errors early
SHELL ["/bin/bash", "-euo", "pipefail", "-c"]
@@ -149,8 +152,14 @@ RUN case $DEBIAN_VERSION in \
libclang-dev \
jsonnet \
$VERSION_INSTALLS \
&& apt clean && rm -rf /var/lib/apt/lists/* && \
useradd -ms /bin/bash nonroot -b /home
&& apt clean && rm -rf /var/lib/apt/lists/* \
&& useradd -ms /bin/bash nonroot -b /home \
# Install protoc from binary release, since Debian's versions are too old.
&& curl -fsSL "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-$(uname -m | sed 's/aarch64/aarch_64/g').zip" -o "protoc.zip" \
&& unzip -q protoc.zip -d protoc \
&& mv protoc/bin/protoc /usr/local/bin/protoc \
&& mv protoc/include/google /usr/local/include/google \
&& rm -rf protoc.zip protoc
#########################################################################################
#
@@ -1170,7 +1179,7 @@ COPY --from=pgrag-src /ext-src/ /ext-src/
# Install it using virtual environment, because Python 3.11 (the default version on Debian 12 (Bookworm)) complains otherwise
WORKDIR /ext-src/onnxruntime-src
RUN apt update && apt install --no-install-recommends --no-install-suggests -y \
python3 python3-pip python3-venv protobuf-compiler && \
python3 python3-pip python3-venv && \
apt clean && rm -rf /var/lib/apt/lists/* && \
python3 -m venv venv && \
. venv/bin/activate && \

View File

@@ -38,6 +38,7 @@ once_cell.workspace = true
opentelemetry.workspace = true
opentelemetry_sdk.workspace = true
p256 = { version = "0.13", features = ["pem"] }
pageserver_page_api.workspace = true
postgres.workspace = true
regex.workspace = true
reqwest = { workspace = true, features = ["json"] }
@@ -53,6 +54,7 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
tokio-postgres.workspace = true
tokio-util.workspace = true
tokio-stream.workspace = true
tonic.workspace = true
tower-otel.workspace = true
tracing.workspace = true
tracing-opentelemetry.workspace = true

View File

@@ -36,6 +36,8 @@
use std::ffi::OsString;
use std::fs::File;
use std::process::exit;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
@@ -190,7 +192,9 @@ fn main() -> Result<()> {
cgroup: cli.cgroup,
#[cfg(target_os = "linux")]
vm_monitor_addr: cli.vm_monitor_addr,
installed_extensions_collection_interval: cli.installed_extensions_collection_interval,
installed_extensions_collection_interval: Arc::new(AtomicU64::new(
cli.installed_extensions_collection_interval,
)),
},
config,
)?;

View File

@@ -6,7 +6,7 @@ use compute_api::responses::{
LfcPrewarmState, TlsConfig,
};
use compute_api::spec::{
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PgIdent,
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverProtocol, PgIdent,
};
use futures::StreamExt;
use futures::future::join_all;
@@ -15,17 +15,17 @@ use itertools::Itertools;
use nix::sys::signal::{Signal, kill};
use nix::unistd::Pid;
use once_cell::sync::Lazy;
use pageserver_page_api::{self as page_api, BaseBackupCompression};
use postgres;
use postgres::NoTls;
use postgres::error::SqlState;
use remote_storage::{DownloadError, RemotePath};
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::os::unix::fs::{PermissionsExt, symlink};
use std::path::Path;
use std::process::{Command, Stdio};
use std::str::FromStr;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
use std::{env, fs};
@@ -36,6 +36,7 @@ use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use utils::measured_stream::MeasuredReader;
use utils::pid_file;
use utils::shard::{ShardCount, ShardIndex, ShardNumber};
use crate::configurator::launch_configurator;
use crate::disk_quota::set_disk_quota;
@@ -69,6 +70,7 @@ pub static BUILD_TAG: Lazy<String> = Lazy::new(|| {
.unwrap_or(BUILD_TAG_DEFAULT)
.to_string()
});
const DEFAULT_INSTALLED_EXTENSIONS_COLLECTION_INTERVAL: u64 = 3600;
/// Static configuration params that don't change after startup. These mostly
/// come from the CLI args, or are derived from them.
@@ -102,7 +104,7 @@ pub struct ComputeNodeParams {
pub remote_ext_base_url: Option<Url>,
/// Interval for installed extensions collection
pub installed_extensions_collection_interval: u64,
pub installed_extensions_collection_interval: Arc<AtomicU64>,
}
/// Compute node info shared across several `compute_ctl` threads.
@@ -125,6 +127,9 @@ pub struct ComputeNode {
// key: ext_archive_name, value: started download time, download_completed?
pub ext_download_progress: RwLock<HashMap<String, (DateTime<Utc>, bool)>>,
pub compute_ctl_config: ComputeCtlConfig,
/// Handle to the extension stats collection task
extension_stats_task: Mutex<Option<tokio::task::JoinHandle<()>>>,
}
// store some metrics about download size that might impact startup time
@@ -218,7 +223,8 @@ pub struct ParsedSpec {
pub pageserver_connstr: String,
pub safekeeper_connstrings: Vec<String>,
pub storage_auth_token: Option<String>,
pub endpoint_storage_addr: Option<SocketAddr>,
/// k8s dns name and port
pub endpoint_storage_addr: Option<String>,
pub endpoint_storage_token: Option<String>,
}
@@ -313,13 +319,10 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
.or(Err("invalid timeline id"))?
};
let endpoint_storage_addr: Option<SocketAddr> = spec
let endpoint_storage_addr: Option<String> = spec
.endpoint_storage_addr
.clone()
.or_else(|| spec.cluster.settings.find("neon.endpoint_storage_addr"))
.unwrap_or_default()
.parse()
.ok();
.or_else(|| spec.cluster.settings.find("neon.endpoint_storage_addr"));
let endpoint_storage_token = spec
.endpoint_storage_token
.clone()
@@ -429,6 +432,7 @@ impl ComputeNode {
state_changed: Condvar::new(),
ext_download_progress: RwLock::new(HashMap::new()),
compute_ctl_config: config.compute_ctl_config,
extension_stats_task: Mutex::new(None),
})
}
@@ -516,6 +520,9 @@ impl ComputeNode {
None
};
// Terminate the extension stats collection task
this.terminate_extension_stats_task();
// Terminate the vm_monitor so it releases the file watcher on
// /sys/fs/cgroup/neon-postgres.
// Note: the vm-monitor only runs on linux because it requires cgroups.
@@ -998,13 +1005,80 @@ impl ComputeNode {
Ok(())
}
// Get basebackup from the libpq connection to pageserver using `connstr` and
// unarchive it to `pgdata` directory overriding all its previous content.
/// Fetches a basebackup from the Pageserver using the compute state's Pageserver connstring and
/// unarchives it to `pgdata` directory, replacing any existing contents.
#[instrument(skip_all, fields(%lsn))]
fn try_get_basebackup(&self, compute_state: &ComputeState, lsn: Lsn) -> Result<()> {
let spec = compute_state.pspec.as_ref().expect("spec must be set");
let start_time = Instant::now();
let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap();
let started = Instant::now();
let (connected, size) = match PageserverProtocol::from_connstring(shard0_connstr)? {
PageserverProtocol::Libpq => self.try_get_basebackup_libpq(spec, lsn)?,
PageserverProtocol::Grpc => self.try_get_basebackup_grpc(spec, lsn)?,
};
let mut state = self.state.lock().unwrap();
state.metrics.pageserver_connect_micros =
connected.duration_since(started).as_micros() as u64;
state.metrics.basebackup_bytes = size as u64;
state.metrics.basebackup_ms = started.elapsed().as_millis() as u64;
Ok(())
}
/// Fetches a basebackup via gRPC. The connstring must use grpc://. Returns the timestamp when
/// the connection was established, and the (compressed) size of the basebackup.
fn try_get_basebackup_grpc(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> {
let shard0_connstr = spec
.pageserver_connstr
.split(',')
.next()
.unwrap()
.to_string();
let shard_index = match spec.pageserver_connstr.split(',').count() as u8 {
0 | 1 => ShardIndex::unsharded(),
count => ShardIndex::new(ShardNumber(0), ShardCount(count)),
};
let (reader, connected) = tokio::runtime::Handle::current().block_on(async move {
let mut client = page_api::Client::new(
shard0_connstr,
spec.tenant_id,
spec.timeline_id,
shard_index,
spec.storage_auth_token.clone(),
None, // NB: base backups use payload compression
)
.await?;
let connected = Instant::now();
let reader = client
.get_base_backup(page_api::GetBaseBackupRequest {
lsn: (lsn != Lsn(0)).then_some(lsn),
compression: BaseBackupCompression::Gzip,
replica: spec.spec.mode != ComputeMode::Primary,
full: false,
})
.await?;
anyhow::Ok((reader, connected))
})?;
let mut reader = MeasuredReader::new(tokio_util::io::SyncIoBridge::new(reader));
// Set `ignore_zeros` so that unpack() reads the entire stream and doesn't just stop at the
// end-of-archive marker. If the server errors, the tar::Builder drop handler will write an
// end-of-archive marker before the error is emitted, and we would not see the error.
let mut ar = tar::Archive::new(flate2::read::GzDecoder::new(&mut reader));
ar.set_ignore_zeros(true);
ar.unpack(&self.params.pgdata)?;
Ok((connected, reader.get_byte_count()))
}
/// Fetches a basebackup via libpq. The connstring must use postgresql://. Returns the timestamp
/// when the connection was established, and the (compressed) size of the basebackup.
fn try_get_basebackup_libpq(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> {
let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap();
let mut config = postgres::Config::from_str(shard0_connstr)?;
@@ -1018,16 +1092,14 @@ impl ComputeNode {
}
config.application_name("compute_ctl");
if let Some(spec) = &compute_state.pspec {
config.options(&format!(
"-c neon.compute_mode={}",
spec.spec.mode.to_type_str()
));
}
config.options(&format!(
"-c neon.compute_mode={}",
spec.spec.mode.to_type_str()
));
// Connect to pageserver
let mut client = config.connect(NoTls)?;
let pageserver_connect_micros = start_time.elapsed().as_micros() as u64;
let connected = Instant::now();
let basebackup_cmd = match lsn {
Lsn(0) => {
@@ -1064,16 +1136,13 @@ impl ComputeNode {
// Set `ignore_zeros` so that unpack() reads all the Copy data and
// doesn't stop at the end-of-archive marker. Otherwise, if the server
// sends an Error after finishing the tarball, we will not notice it.
// The tar::Builder drop handler will write an end-of-archive marker
// before emitting the error, and we would not see it otherwise.
let mut ar = tar::Archive::new(flate2::read::GzDecoder::new(&mut bufreader));
ar.set_ignore_zeros(true);
ar.unpack(&self.params.pgdata)?;
// Report metrics
let mut state = self.state.lock().unwrap();
state.metrics.pageserver_connect_micros = pageserver_connect_micros;
state.metrics.basebackup_bytes = measured_reader.get_byte_count() as u64;
state.metrics.basebackup_ms = start_time.elapsed().as_millis() as u64;
Ok(())
Ok((connected, measured_reader.get_byte_count()))
}
// Gets the basebackup in a retry loop
@@ -1610,6 +1679,8 @@ impl ComputeNode {
tls_config = self.compute_ctl_config.tls.clone();
}
self.update_installed_extensions_collection_interval(&spec);
let max_concurrent_connections = self.max_service_connections(compute_state, &spec);
// Merge-apply spec & changes to PostgreSQL state.
@@ -1674,6 +1745,8 @@ impl ComputeNode {
let tls_config = self.tls_config(&spec);
self.update_installed_extensions_collection_interval(&spec);
if let Some(ref pgbouncer_settings) = spec.pgbouncer_settings {
info!("tuning pgbouncer");
@@ -2278,10 +2351,20 @@ LIMIT 100",
}
pub fn spawn_extension_stats_task(&self) {
// Cancel any existing task
if let Some(handle) = self.extension_stats_task.lock().unwrap().take() {
handle.abort();
}
let conf = self.tokio_conn_conf.clone();
let installed_extensions_collection_interval =
self.params.installed_extensions_collection_interval;
tokio::spawn(async move {
let atomic_interval = self.params.installed_extensions_collection_interval.clone();
let mut installed_extensions_collection_interval =
2 * atomic_interval.load(std::sync::atomic::Ordering::SeqCst);
info!(
"[NEON_EXT_SPAWN] Spawning background installed extensions worker with Timeout: {}",
installed_extensions_collection_interval
);
let handle = tokio::spawn(async move {
// An initial sleep is added to ensure that two collections don't happen at the same time.
// The first collection happens during compute startup.
tokio::time::sleep(tokio::time::Duration::from_secs(
@@ -2294,8 +2377,48 @@ LIMIT 100",
loop {
interval.tick().await;
let _ = installed_extensions(conf.clone()).await;
// Acquire a read lock on the compute spec and then update the interval if necessary
interval = tokio::time::interval(tokio::time::Duration::from_secs(std::cmp::max(
installed_extensions_collection_interval,
2 * atomic_interval.load(std::sync::atomic::Ordering::SeqCst),
)));
installed_extensions_collection_interval = interval.period().as_secs();
}
});
// Store the new task handle
*self.extension_stats_task.lock().unwrap() = Some(handle);
}
fn terminate_extension_stats_task(&self) {
if let Some(handle) = self.extension_stats_task.lock().unwrap().take() {
handle.abort();
}
}
fn update_installed_extensions_collection_interval(&self, spec: &ComputeSpec) {
// Update the interval for collecting installed extensions statistics
// If the value is -1, we never suspend so set the value to default collection.
// If the value is 0, it means default, we will just continue to use the default.
if spec.suspend_timeout_seconds == -1 || spec.suspend_timeout_seconds == 0 {
info!(
"[NEON_EXT_INT_UPD] Spec Timeout: {}, New Timeout: {}",
spec.suspend_timeout_seconds, DEFAULT_INSTALLED_EXTENSIONS_COLLECTION_INTERVAL
);
self.params.installed_extensions_collection_interval.store(
DEFAULT_INSTALLED_EXTENSIONS_COLLECTION_INTERVAL,
std::sync::atomic::Ordering::SeqCst,
);
} else {
info!(
"[NEON_EXT_INT_UPD] Spec Timeout: {}",
spec.suspend_timeout_seconds
);
self.params.installed_extensions_collection_interval.store(
spec.suspend_timeout_seconds as u64,
std::sync::atomic::Ordering::SeqCst,
);
}
}
}

View File

@@ -4,7 +4,9 @@ use std::thread;
use std::time::{Duration, SystemTime};
use anyhow::{Result, bail};
use compute_api::spec::ComputeMode;
use compute_api::spec::{ComputeMode, PageserverProtocol};
use itertools::Itertools as _;
use pageserver_page_api as page_api;
use postgres::{NoTls, SimpleQueryMessage};
use tracing::{info, warn};
use utils::id::{TenantId, TimelineId};
@@ -76,25 +78,17 @@ fn acquire_lsn_lease_with_retry(
loop {
// Note: List of pageservers is dynamic, need to re-read configs before each attempt.
let configs = {
let (connstrings, auth) = {
let state = compute.state.lock().unwrap();
let spec = state.pspec.as_ref().expect("spec must be set");
let conn_strings = spec.pageserver_connstr.split(',');
conn_strings
.map(|connstr| {
let mut config = postgres::Config::from_str(connstr).expect("Invalid connstr");
if let Some(storage_auth_token) = &spec.storage_auth_token {
config.password(storage_auth_token.clone());
}
config
})
.collect::<Vec<_>>()
(
spec.pageserver_connstr.clone(),
spec.storage_auth_token.clone(),
)
};
let result = try_acquire_lsn_lease(tenant_id, timeline_id, lsn, &configs);
let result =
try_acquire_lsn_lease(&connstrings, auth.as_deref(), tenant_id, timeline_id, lsn);
match result {
Ok(Some(res)) => {
return Ok(res);
@@ -116,68 +110,104 @@ fn acquire_lsn_lease_with_retry(
}
}
/// Tries to acquire an LSN lease through PS page_service API.
/// Tries to acquire LSN leases on all Pageserver shards.
fn try_acquire_lsn_lease(
connstrings: &str,
auth: Option<&str>,
tenant_id: TenantId,
timeline_id: TimelineId,
lsn: Lsn,
configs: &[postgres::Config],
) -> Result<Option<SystemTime>> {
fn get_valid_until(
config: &postgres::Config,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
lsn: Lsn,
) -> Result<Option<SystemTime>> {
let mut client = config.connect(NoTls)?;
let cmd = format!("lease lsn {tenant_shard_id} {timeline_id} {lsn} ");
let res = client.simple_query(&cmd)?;
let msg = match res.first() {
Some(msg) => msg,
None => bail!("empty response"),
};
let row = match msg {
SimpleQueryMessage::Row(row) => row,
_ => bail!("error parsing lsn lease response"),
let connstrings = connstrings.split(',').collect_vec();
let shard_count = connstrings.len();
let mut leases = Vec::new();
for (shard_number, &connstring) in connstrings.iter().enumerate() {
let tenant_shard_id = match shard_count {
0 | 1 => TenantShardId::unsharded(tenant_id),
shard_count => TenantShardId {
tenant_id,
shard_number: ShardNumber(shard_number as u8),
shard_count: ShardCount::new(shard_count as u8),
},
};
// Note: this will be None if a lease is explicitly not granted.
let valid_until_str = row.get("valid_until");
let valid_until = valid_until_str.map(|s| {
SystemTime::UNIX_EPOCH
.checked_add(Duration::from_millis(u128::from_str(s).unwrap() as u64))
.expect("Time larger than max SystemTime could handle")
});
Ok(valid_until)
let lease = match PageserverProtocol::from_connstring(connstring)? {
PageserverProtocol::Libpq => {
acquire_lsn_lease_libpq(connstring, auth, tenant_shard_id, timeline_id, lsn)?
}
PageserverProtocol::Grpc => {
acquire_lsn_lease_grpc(connstring, auth, tenant_shard_id, timeline_id, lsn)?
}
};
leases.push(lease);
}
let shard_count = configs.len();
Ok(leases.into_iter().min().flatten())
}
let valid_until = if shard_count > 1 {
configs
.iter()
.enumerate()
.map(|(shard_number, config)| {
let tenant_shard_id = TenantShardId {
tenant_id,
shard_count: ShardCount::new(shard_count as u8),
shard_number: ShardNumber(shard_number as u8),
};
get_valid_until(config, tenant_shard_id, timeline_id, lsn)
})
.collect::<Result<Vec<Option<SystemTime>>>>()?
.into_iter()
.min()
.unwrap()
} else {
get_valid_until(
&configs[0],
TenantShardId::unsharded(tenant_id),
timeline_id,
lsn,
)?
/// Acquires an LSN lease on a single shard, using the libpq API. The connstring must use a
/// postgresql:// scheme.
fn acquire_lsn_lease_libpq(
connstring: &str,
auth: Option<&str>,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
lsn: Lsn,
) -> Result<Option<SystemTime>> {
let mut config = postgres::Config::from_str(connstring)?;
if let Some(auth) = auth {
config.password(auth);
}
let mut client = config.connect(NoTls)?;
let cmd = format!("lease lsn {tenant_shard_id} {timeline_id} {lsn} ");
let res = client.simple_query(&cmd)?;
let msg = match res.first() {
Some(msg) => msg,
None => bail!("empty response"),
};
let row = match msg {
SimpleQueryMessage::Row(row) => row,
_ => bail!("error parsing lsn lease response"),
};
// Note: this will be None if a lease is explicitly not granted.
let valid_until_str = row.get("valid_until");
let valid_until = valid_until_str.map(|s| {
SystemTime::UNIX_EPOCH
.checked_add(Duration::from_millis(u128::from_str(s).unwrap() as u64))
.expect("Time larger than max SystemTime could handle")
});
Ok(valid_until)
}
/// Acquires an LSN lease on a single shard, using the gRPC API. The connstring must use a
/// grpc:// scheme.
fn acquire_lsn_lease_grpc(
connstring: &str,
auth: Option<&str>,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
lsn: Lsn,
) -> Result<Option<SystemTime>> {
tokio::runtime::Handle::current().block_on(async move {
let mut client = page_api::Client::new(
connstring.to_string(),
tenant_shard_id.tenant_id,
timeline_id,
tenant_shard_id.to_index(),
auth.map(String::from),
None,
)
.await?;
let req = page_api::LeaseLsnRequest { lsn };
match client.lease_lsn(req).await {
Ok(expires) => Ok(Some(expires)),
// Lease couldn't be acquired because the LSN has been garbage collected.
Err(err) if err.code() == tonic::Code::FailedPrecondition => Ok(None),
Err(err) => Err(err.into()),
}
})
}

View File

@@ -3,7 +3,8 @@
"timestamp": "2021-05-23T18:25:43.511Z",
"operation_uuid": "0f657b36-4b0f-4a2d-9c2e-1dcd615e7d8b",
"suspend_timeout_seconds": 3600,
"cluster": {
"cluster_id": "test-cluster-42",
"name": "Zenith Test",

View File

@@ -16,9 +16,9 @@ use std::time::Duration;
use anyhow::{Context, Result, anyhow, bail};
use clap::Parser;
use compute_api::requests::ComputeClaimsScope;
use compute_api::spec::ComputeMode;
use compute_api::spec::{ComputeMode, PageserverProtocol};
use control_plane::broker::StorageBroker;
use control_plane::endpoint::{ComputeControlPlane, EndpointTerminateMode, PageserverProtocol};
use control_plane::endpoint::{ComputeControlPlane, EndpointTerminateMode};
use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage};
use control_plane::local_env;
use control_plane::local_env::{
@@ -1649,7 +1649,9 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
// If --safekeepers argument is given, use only the listed
// safekeeper nodes; otherwise all from the env.
let safekeepers = parse_safekeepers(&args.safekeepers)?;
endpoint.reconfigure(pageservers, None, safekeepers).await?;
endpoint
.reconfigure(Some(pageservers), None, safekeepers, None)
.await?;
}
EndpointCmd::Stop(args) => {
let endpoint_id = &args.endpoint_id;

View File

@@ -56,8 +56,8 @@ use compute_api::responses::{
TlsConfig,
};
use compute_api::spec::{
Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PgIdent,
RemoteExtSpec, Role,
Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PageserverProtocol,
PgIdent, RemoteExtSpec, Role,
};
use jsonwebtoken::jwk::{
AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations,
@@ -373,29 +373,6 @@ impl std::fmt::Display for EndpointTerminateMode {
}
}
/// Protocol used to connect to a Pageserver.
#[derive(Clone, Copy, Debug)]
pub enum PageserverProtocol {
Libpq,
Grpc,
}
impl PageserverProtocol {
/// Returns the URL scheme for the protocol, used in connstrings.
pub fn scheme(&self) -> &'static str {
match self {
Self::Libpq => "postgresql",
Self::Grpc => "grpc",
}
}
}
impl Display for PageserverProtocol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.scheme())
}
}
impl Endpoint {
fn from_dir_entry(entry: std::fs::DirEntry, env: &LocalEnv) -> Result<Endpoint> {
if !entry.file_type()?.is_dir() {
@@ -803,6 +780,7 @@ impl Endpoint {
endpoint_storage_addr: Some(endpoint_storage_addr),
endpoint_storage_token: Some(endpoint_storage_token),
autoprewarm: false,
suspend_timeout_seconds: -1, // Only used in neon_local.
};
// this strange code is needed to support respec() in tests
@@ -997,12 +975,11 @@ impl Endpoint {
pub async fn reconfigure(
&self,
pageservers: Vec<(PageserverProtocol, Host, u16)>,
pageservers: Option<Vec<(PageserverProtocol, Host, u16)>>,
stripe_size: Option<ShardStripeSize>,
safekeepers: Option<Vec<NodeId>>,
safekeeper_generation: Option<SafekeeperGeneration>,
) -> Result<()> {
anyhow::ensure!(!pageservers.is_empty(), "no pageservers provided");
let (mut spec, compute_ctl_config) = {
let config_path = self.endpoint_path().join("config.json");
let file = std::fs::File::open(config_path)?;
@@ -1014,16 +991,24 @@ impl Endpoint {
let postgresql_conf = self.read_postgresql_conf()?;
spec.cluster.postgresql_conf = Some(postgresql_conf);
let pageserver_connstr = Self::build_pageserver_connstr(&pageservers);
spec.pageserver_connstring = Some(pageserver_connstr);
if stripe_size.is_some() {
spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize);
// If pageservers are not specified, don't change them.
if let Some(pageservers) = pageservers {
anyhow::ensure!(!pageservers.is_empty(), "no pageservers provided");
let pageserver_connstr = Self::build_pageserver_connstr(&pageservers);
spec.pageserver_connstring = Some(pageserver_connstr);
if stripe_size.is_some() {
spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize);
}
}
// If safekeepers are not specified, don't change them.
if let Some(safekeepers) = safekeepers {
let safekeeper_connstrings = self.build_safekeepers_connstrs(safekeepers)?;
spec.safekeeper_connstrings = safekeeper_connstrings;
if let Some(g) = safekeeper_generation {
spec.safekeepers_generation = Some(g.into_inner());
}
}
let client = reqwest::Client::builder()
@@ -1061,6 +1046,24 @@ impl Endpoint {
}
}
pub async fn reconfigure_pageservers(
&self,
pageservers: Vec<(PageserverProtocol, Host, u16)>,
stripe_size: Option<ShardStripeSize>,
) -> Result<()> {
self.reconfigure(Some(pageservers), stripe_size, None, None)
.await
}
pub async fn reconfigure_safekeepers(
&self,
safekeepers: Vec<NodeId>,
generation: SafekeeperGeneration,
) -> Result<()> {
self.reconfigure(None, None, Some(safekeepers), Some(generation))
.await
}
pub async fn stop(
&self,
mode: EndpointTerminateMode,

View File

@@ -212,7 +212,7 @@ pub struct NeonStorageControllerConf {
pub use_local_compute_notifications: bool,
pub timeline_safekeeper_count: Option<i64>,
pub timeline_safekeeper_count: Option<usize>,
pub posthog_config: Option<PostHogConfig>,

View File

@@ -638,7 +638,13 @@ impl StorageController {
args.push("--timelines-onto-safekeepers".to_string());
}
if let Some(sk_cnt) = self.config.timeline_safekeeper_count {
// neon_local is used in test environments where we often have less than 3 safekeepers.
if self.config.timeline_safekeeper_count.is_some() || self.env.safekeepers.len() < 3 {
let sk_cnt = self
.config
.timeline_safekeeper_count
.unwrap_or(self.env.safekeepers.len());
args.push(format!("--timeline-safekeeper-count={sk_cnt}"));
}

View File

@@ -4,6 +4,7 @@
"timestamp": "2022-10-12T18:00:00.000Z",
"operation_uuid": "0f657b36-4b0f-4a2d-9c2e-1dcd615e7d8c",
"suspend_timeout_seconds": -1,
"cluster": {
"cluster_id": "docker_compose",

View File

@@ -12,6 +12,7 @@ jsonwebtoken.workspace = true
serde.workspace = true
serde_json.workspace = true
regex.workspace = true
url.workspace = true
utils = { path = "../utils" }
remote_storage = { version = "0.1", path = "../remote_storage/" }

View File

@@ -4,11 +4,14 @@
//! provide it by calling the compute_ctl's `/compute_ctl` endpoint, or
//! compute_ctl can fetch it by calling the control plane's API.
use std::collections::HashMap;
use std::fmt::Display;
use anyhow::anyhow;
use indexmap::IndexMap;
use regex::Regex;
use remote_storage::RemotePath;
use serde::{Deserialize, Serialize};
use url::Url;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
@@ -181,6 +184,11 @@ pub struct ComputeSpec {
/// Download LFC state from endpoint_storage and pass it to Postgres on startup
#[serde(default)]
pub autoprewarm: bool,
/// Suspend timeout in seconds.
///
/// We use this value to derive other values, such as the installed extensions metric.
pub suspend_timeout_seconds: i64,
}
/// Feature flag to signal `compute_ctl` to enable certain experimental functionality.
@@ -429,6 +437,47 @@ pub struct JwksSettings {
pub jwt_audience: Option<String>,
}
/// Protocol used to connect to a Pageserver. Parsed from the connstring scheme.
#[derive(Clone, Copy, Debug, Default)]
pub enum PageserverProtocol {
/// The original protocol based on libpq and COPY. Uses postgresql:// or postgres:// scheme.
#[default]
Libpq,
/// A newer, gRPC-based protocol. Uses grpc:// scheme.
Grpc,
}
impl PageserverProtocol {
/// Parses the protocol from a connstring scheme. Defaults to Libpq if no scheme is given.
/// Errors if the connstring is an invalid URL.
pub fn from_connstring(connstring: &str) -> anyhow::Result<Self> {
let scheme = match Url::parse(connstring) {
Ok(url) => url.scheme().to_lowercase(),
Err(url::ParseError::RelativeUrlWithoutBase) => return Ok(Self::default()),
Err(err) => return Err(anyhow!("invalid connstring URL: {err}")),
};
match scheme.as_str() {
"postgresql" | "postgres" => Ok(Self::Libpq),
"grpc" => Ok(Self::Grpc),
scheme => Err(anyhow!("invalid protocol scheme: {scheme}")),
}
}
/// Returns the URL scheme for the protocol, for use in connstrings.
pub fn scheme(&self) -> &'static str {
match self {
Self::Libpq => "postgresql",
Self::Grpc => "grpc",
}
}
}
impl Display for PageserverProtocol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.scheme())
}
}
#[cfg(test)]
mod tests {
use std::fs::File;

View File

@@ -3,6 +3,7 @@
"timestamp": "2021-05-23T18:25:43.511Z",
"operation_uuid": "0f657b36-4b0f-4a2d-9c2e-1dcd615e7d8b",
"suspend_timeout_seconds": 3600,
"cluster": {
"cluster_id": "test-cluster-42",

View File

@@ -19,6 +19,7 @@ byteorder.workspace = true
utils.workspace = true
postgres_ffi_types.workspace = true
postgres_versioninfo.workspace = true
posthog_client_lite.workspace = true
enum-map.workspace = true
strum.workspace = true
strum_macros.workspace = true
@@ -29,12 +30,13 @@ humantime-serde.workspace = true
chrono = { workspace = true, features = ["serde"] }
itertools.workspace = true
storage_broker.workspace = true
camino = {workspace = true, features = ["serde1"]}
camino = { workspace = true, features = ["serde1"] }
remote_storage.workspace = true
postgres_backend.workspace = true
nix = {workspace = true, optional = true}
nix = { workspace = true, optional = true }
reqwest.workspace = true
rand.workspace = true
tracing.workspace = true
tracing-utils.workspace = true
once_cell.workspace = true

View File

@@ -4,6 +4,7 @@ use camino::Utf8PathBuf;
mod tests;
use const_format::formatcp;
use posthog_client_lite::PostHogClientConfig;
pub const DEFAULT_PG_LISTEN_PORT: u16 = 64000;
pub const DEFAULT_PG_LISTEN_ADDR: &str = formatcp!("127.0.0.1:{DEFAULT_PG_LISTEN_PORT}");
pub const DEFAULT_HTTP_LISTEN_PORT: u16 = 9898;
@@ -68,15 +69,25 @@ impl Display for NodeMetadata {
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PostHogConfig {
/// PostHog project ID
pub project_id: String,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub project_id: Option<String>,
/// Server-side (private) API key
pub server_api_key: String,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub server_api_key: Option<String>,
/// Client-side (public) API key
pub client_api_key: String,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub client_api_key: Option<String>,
/// Private API URL
pub private_api_url: String,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub private_api_url: Option<String>,
/// Public API URL
pub public_api_url: String,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub public_api_url: Option<String>,
/// Refresh interval for the feature flag spec.
/// The storcon will push the feature flag spec to the pageserver. If the pageserver does not receive
/// the spec for `refresh_interval`, it will fetch the spec from the PostHog API.
@@ -86,6 +97,33 @@ pub struct PostHogConfig {
pub refresh_interval: Option<Duration>,
}
impl PostHogConfig {
pub fn try_into_posthog_config(self) -> Result<PostHogClientConfig, &'static str> {
let Some(project_id) = self.project_id else {
return Err("project_id is required");
};
let Some(server_api_key) = self.server_api_key else {
return Err("server_api_key is required");
};
let Some(client_api_key) = self.client_api_key else {
return Err("client_api_key is required");
};
let Some(private_api_url) = self.private_api_url else {
return Err("private_api_url is required");
};
let Some(public_api_url) = self.public_api_url else {
return Err("public_api_url is required");
};
Ok(PostHogClientConfig {
project_id,
server_api_key,
client_api_key,
private_api_url,
public_api_url,
})
}
}
/// `pageserver.toml`
///
/// We use serde derive with `#[serde(default)]` to generate a deserializer
@@ -371,6 +409,9 @@ pub struct BasebackupCacheConfig {
// TODO(diko): support max_entry_size_bytes.
// pub max_entry_size_bytes: u64,
pub max_size_entries: usize,
/// Size of the channel used to send prepare requests to the basebackup cache worker.
/// If exceeded, new prepare requests will be dropped.
pub prepare_channel_size: usize,
}
impl Default for BasebackupCacheConfig {
@@ -379,7 +420,8 @@ impl Default for BasebackupCacheConfig {
cleanup_period: Duration::from_secs(60),
max_total_size_bytes: 1024 * 1024 * 1024, // 1 GiB
// max_entry_size_bytes: 16 * 1024 * 1024, // 16 MiB
max_size_entries: 1000,
max_size_entries: 10000,
prepare_channel_size: 100,
}
}
}

View File

@@ -546,6 +546,11 @@ pub struct TimelineImportRequest {
pub sk_set: Vec<NodeId>,
}
#[derive(serde::Serialize, serde::Deserialize, Clone)]
pub struct TimelineSafekeeperMigrateRequest {
pub new_sk_set: Vec<NodeId>,
}
#[cfg(test)]
mod test {
use serde_json;

View File

@@ -21,7 +21,9 @@ use utils::{completion, serde_system_time};
use crate::config::Ratio;
use crate::key::{CompactKey, Key};
use crate::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId};
use crate::shard::{
DEFAULT_STRIPE_SIZE, ShardCount, ShardIdentity, ShardStripeSize, TenantShardId,
};
/// The state of a tenant in this pageserver.
///
@@ -475,7 +477,7 @@ pub struct TenantShardSplitResponse {
}
/// Parameters that apply to all shards in a tenant. Used during tenant creation.
#[derive(Serialize, Deserialize, Debug)]
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct ShardParameters {
pub count: ShardCount,
@@ -497,6 +499,15 @@ impl Default for ShardParameters {
}
}
impl From<ShardIdentity> for ShardParameters {
fn from(identity: ShardIdentity) -> Self {
Self {
count: identity.count,
stripe_size: identity.stripe_size,
}
}
}
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub enum FieldPatch<T> {
Upsert(T),

View File

@@ -37,6 +37,7 @@ use std::hash::{Hash, Hasher};
pub use ::utils::shard::*;
use postgres_ffi_types::forknum::INIT_FORKNUM;
use serde::{Deserialize, Serialize};
use utils::critical;
use crate::key::Key;
use crate::models::ShardParameters;
@@ -179,7 +180,7 @@ impl ShardIdentity {
/// For use when creating ShardIdentity instances for new shards, where a creation request
/// specifies the ShardParameters that apply to all shards.
pub fn from_params(number: ShardNumber, params: &ShardParameters) -> Self {
pub fn from_params(number: ShardNumber, params: ShardParameters) -> Self {
Self {
number,
count: params.count,
@@ -188,6 +189,17 @@ impl ShardIdentity {
}
}
/// Asserts that the given shard identities are equal. Changes to shard parameters will likely
/// result in data corruption.
pub fn assert_equal(&self, other: ShardIdentity) {
if self != &other {
// TODO: for now, we're conservative and just log errors in production. Turn this into a
// real assertion when we're confident it doesn't misfire, and also reject requests that
// attempt to change it with an error response.
critical!("shard identity mismatch: {self:?} != {other:?}");
}
}
fn is_broken(&self) -> bool {
self.layout == LAYOUT_BROKEN
}

View File

@@ -210,7 +210,7 @@ pub struct TimelineStatus {
}
/// Request to switch membership configuration.
#[derive(Serialize, Deserialize)]
#[derive(Clone, Serialize, Deserialize)]
#[serde(transparent)]
pub struct TimelineMembershipSwitchRequest {
pub mconf: Configuration,
@@ -221,6 +221,8 @@ pub struct TimelineMembershipSwitchRequest {
pub struct TimelineMembershipSwitchResponse {
pub previous_conf: Configuration,
pub current_conf: Configuration,
pub term: Term,
pub flush_lsn: Lsn,
}
#[derive(Clone, Copy, Serialize, Deserialize)]

View File

@@ -86,6 +86,14 @@ pub enum GateError {
GateClosed,
}
impl GateError {
pub fn is_cancel(&self) -> bool {
match self {
GateError::GateClosed => true,
}
}
}
impl Default for Gate {
fn default() -> Self {
Self {

View File

@@ -9,12 +9,14 @@ anyhow.workspace = true
bytes.workspace = true
futures.workspace = true
pageserver_api.workspace = true
postgres_ffi.workspace = true
postgres_ffi_types.workspace = true
prost.workspace = true
prost-types.workspace = true
strum.workspace = true
strum_macros.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-util.workspace = true
tonic.workspace = true
utils.workspace = true
workspace_hack.workspace = true

View File

@@ -35,6 +35,8 @@
syntax = "proto3";
package page_api;
import "google/protobuf/timestamp.proto";
service PageService {
// Returns whether a relation exists.
rpc CheckRelExists(CheckRelExistsRequest) returns (CheckRelExistsResponse);
@@ -64,6 +66,10 @@ service PageService {
// Fetches an SLRU segment.
rpc GetSlruSegment (GetSlruSegmentRequest) returns (GetSlruSegmentResponse);
// Acquires or extends a lease on the given LSN. This guarantees that the Pageserver won't garbage
// collect the LSN until the lease expires. Must be acquired on all relevant shards.
rpc LeaseLsn (LeaseLsnRequest) returns (LeaseLsnResponse);
}
// The LSN a request should read at.
@@ -252,3 +258,17 @@ message GetSlruSegmentRequest {
message GetSlruSegmentResponse {
bytes segment = 1;
}
// Acquires or extends a lease on the given LSN. This guarantees that the Pageserver won't garbage
// collect the LSN until the lease expires. Must be acquired on all relevant shards.
message LeaseLsnRequest {
// The LSN to lease. Can't be 0 or below the current GC cutoff.
uint64 lsn = 1;
}
// Lease acquisition response. If the lease could not be granted because the LSN has already been
// garbage collected, a FailedPrecondition status will be returned instead.
message LeaseLsnResponse {
// The lease expiration time.
google.protobuf.Timestamp expires = 1;
}

View File

@@ -1,8 +1,7 @@
use std::convert::TryInto;
use bytes::Bytes;
use futures::TryStreamExt;
use futures::{Stream, StreamExt};
use anyhow::Result;
use futures::{Stream, StreamExt as _, TryStreamExt as _};
use tokio::io::AsyncRead;
use tokio_util::io::StreamReader;
use tonic::metadata::AsciiMetadataValue;
use tonic::metadata::errors::InvalidMetadataValue;
use tonic::transport::Channel;
@@ -12,8 +11,6 @@ use utils::id::TenantId;
use utils::id::TimelineId;
use utils::shard::ShardIndex;
use anyhow::Result;
use crate::model;
use crate::proto;
@@ -69,6 +66,7 @@ impl tonic::service::Interceptor for AuthInterceptor {
Ok(req)
}
}
#[derive(Clone)]
pub struct Client {
client: proto::PageServiceClient<
@@ -120,22 +118,15 @@ impl Client {
pub async fn get_base_backup(
&mut self,
req: model::GetBaseBackupRequest,
) -> Result<impl Stream<Item = Result<Bytes, tonic::Status>> + 'static, tonic::Status> {
let proto_req = proto::GetBaseBackupRequest::from(req);
let response_stream: Streaming<proto::GetBaseBackupResponseChunk> =
self.client.get_base_backup(proto_req).await?.into_inner();
// TODO: Consider dechunking internally
let domain_stream = response_stream.map(|chunk_res| {
chunk_res.and_then(|proto_chunk| {
proto_chunk.try_into().map_err(|e| {
tonic::Status::internal(format!("Failed to convert response chunk: {e}"))
})
})
});
Ok(domain_stream)
) -> Result<impl AsyncRead + use<>, tonic::Status> {
let req = proto::GetBaseBackupRequest::from(req);
let chunks = self.client.get_base_backup(req).await?.into_inner();
let reader = StreamReader::new(
chunks
.map_ok(|resp| resp.chunk)
.map_err(std::io::Error::other),
);
Ok(reader)
}
/// Returns the total size of a database, as # of bytes.
@@ -196,4 +187,17 @@ impl Client {
let response = self.client.get_slru_segment(proto_req).await?;
Ok(response.into_inner().try_into()?)
}
/// Acquires or extends a lease on the given LSN. This guarantees that the Pageserver won't
/// garbage collect the LSN until the lease expires. Must be acquired on all relevant shards.
///
/// Returns the lease expiration time, or a FailedPrecondition status if the lease could not be
/// acquired because the LSN has already been garbage collected.
pub async fn lease_lsn(
&mut self,
req: model::LeaseLsnRequest,
) -> Result<model::LeaseLsnResponse, tonic::Status> {
let req = proto::LeaseLsnRequest::from(req);
Ok(self.client.lease_lsn(req).await?.into_inner().try_into()?)
}
}

View File

@@ -16,10 +16,11 @@
//! stream combinators without dealing with errors, and avoids validating the same message twice.
use std::fmt::Display;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use bytes::Bytes;
use postgres_ffi::Oid;
// TODO: split out Lsn, RelTag, SlruKind, Oid and other basic types to a separate crate, to avoid
use postgres_ffi_types::Oid;
// TODO: split out Lsn, RelTag, SlruKind and other basic types to a separate crate, to avoid
// pulling in all of their other crate dependencies when building the client.
use utils::lsn::Lsn;
@@ -703,3 +704,54 @@ impl From<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
// SlruKind is defined in pageserver_api::reltag.
pub type SlruKind = pageserver_api::reltag::SlruKind;
/// Acquires or extends a lease on the given LSN. This guarantees that the Pageserver won't garbage
/// collect the LSN until the lease expires.
pub struct LeaseLsnRequest {
/// The LSN to lease.
pub lsn: Lsn,
}
impl TryFrom<proto::LeaseLsnRequest> for LeaseLsnRequest {
type Error = ProtocolError;
fn try_from(pb: proto::LeaseLsnRequest) -> Result<Self, Self::Error> {
if pb.lsn == 0 {
return Err(ProtocolError::Missing("lsn"));
}
Ok(Self { lsn: Lsn(pb.lsn) })
}
}
impl From<LeaseLsnRequest> for proto::LeaseLsnRequest {
fn from(request: LeaseLsnRequest) -> Self {
Self { lsn: request.lsn.0 }
}
}
/// Lease expiration time. If the lease could not be granted because the LSN has already been
/// garbage collected, a FailedPrecondition status will be returned instead.
pub type LeaseLsnResponse = SystemTime;
impl TryFrom<proto::LeaseLsnResponse> for LeaseLsnResponse {
type Error = ProtocolError;
fn try_from(pb: proto::LeaseLsnResponse) -> Result<Self, Self::Error> {
let expires = pb.expires.ok_or(ProtocolError::Missing("expires"))?;
UNIX_EPOCH
.checked_add(Duration::new(expires.seconds as u64, expires.nanos as u32))
.ok_or_else(|| ProtocolError::invalid("expires", expires))
}
}
impl From<LeaseLsnResponse> for proto::LeaseLsnResponse {
fn from(response: LeaseLsnResponse) -> Self {
let expires = response.duration_since(UNIX_EPOCH).unwrap_or_default();
Self {
expires: Some(prost_types::Timestamp {
seconds: expires.as_secs() as i64,
nanos: expires.subsec_nanos() as i32,
}),
}
}
}

View File

@@ -355,9 +355,6 @@ impl Client for GrpcClient {
full: false,
compression: self.compression,
};
let stream = self.inner.get_base_backup(req).await?;
Ok(Box::pin(StreamReader::new(
stream.map_err(std::io::Error::other),
)))
Ok(Box::pin(self.inner.get_base_backup(req).await?))
}
}

View File

@@ -6,7 +6,7 @@ use metrics::core::{AtomicU64, GenericCounter};
use pageserver_api::{config::BasebackupCacheConfig, models::TenantState};
use tokio::{
io::{AsyncWriteExt, BufWriter},
sync::mpsc::{UnboundedReceiver, UnboundedSender},
sync::mpsc::{Receiver, Sender, error::TrySendError},
};
use tokio_util::sync::CancellationToken;
use utils::{
@@ -19,8 +19,8 @@ use crate::{
basebackup::send_basebackup_tarball,
context::{DownloadBehavior, RequestContext},
metrics::{
BASEBACKUP_CACHE_ENTRIES, BASEBACKUP_CACHE_PREPARE, BASEBACKUP_CACHE_READ,
BASEBACKUP_CACHE_SIZE,
BASEBACKUP_CACHE_ENTRIES, BASEBACKUP_CACHE_PREPARE, BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE,
BASEBACKUP_CACHE_READ, BASEBACKUP_CACHE_SIZE,
},
task_mgr::TaskKind,
tenant::{
@@ -35,8 +35,8 @@ pub struct BasebackupPrepareRequest {
pub lsn: Lsn,
}
pub type BasebackupPrepareSender = UnboundedSender<BasebackupPrepareRequest>;
pub type BasebackupPrepareReceiver = UnboundedReceiver<BasebackupPrepareRequest>;
pub type BasebackupPrepareSender = Sender<BasebackupPrepareRequest>;
pub type BasebackupPrepareReceiver = Receiver<BasebackupPrepareRequest>;
#[derive(Clone)]
struct CacheEntry {
@@ -60,40 +60,65 @@ struct CacheEntry {
/// and ~1 RPS for get requests.
pub struct BasebackupCache {
data_dir: Utf8PathBuf,
config: Option<BasebackupCacheConfig>,
entries: std::sync::Mutex<HashMap<TenantTimelineId, CacheEntry>>,
prepare_sender: BasebackupPrepareSender,
read_hit_count: GenericCounter<AtomicU64>,
read_miss_count: GenericCounter<AtomicU64>,
read_err_count: GenericCounter<AtomicU64>,
prepare_skip_count: GenericCounter<AtomicU64>,
}
impl BasebackupCache {
/// Creates a BasebackupCache and spawns the background task.
/// The initialization of the cache is performed in the background and does not
/// block the caller. The cache will return `None` for any get requests until
/// initialization is complete.
pub fn spawn(
runtime_handle: &tokio::runtime::Handle,
/// Create a new BasebackupCache instance.
/// Also returns a BasebackupPrepareReceiver which is needed to start
/// the background task.
/// The cache is initialized from the data_dir in the background task.
/// The cache will return `None` for any get requests until the initialization is complete.
/// The background task is spawned separately using [`Self::spawn_background_task`]
/// to avoid a circular dependency between the cache and the tenant manager.
pub fn new(
data_dir: Utf8PathBuf,
config: Option<BasebackupCacheConfig>,
prepare_receiver: BasebackupPrepareReceiver,
tenant_manager: Arc<TenantManager>,
cancel: CancellationToken,
) -> Arc<Self> {
) -> (Arc<Self>, BasebackupPrepareReceiver) {
let chan_size = config.as_ref().map(|c| c.max_size_entries).unwrap_or(1);
let (prepare_sender, prepare_receiver) = tokio::sync::mpsc::channel(chan_size);
let cache = Arc::new(BasebackupCache {
data_dir,
config,
entries: std::sync::Mutex::new(HashMap::new()),
prepare_sender,
read_hit_count: BASEBACKUP_CACHE_READ.with_label_values(&["hit"]),
read_miss_count: BASEBACKUP_CACHE_READ.with_label_values(&["miss"]),
read_err_count: BASEBACKUP_CACHE_READ.with_label_values(&["error"]),
prepare_skip_count: BASEBACKUP_CACHE_PREPARE.with_label_values(&["skip"]),
});
if let Some(config) = config {
(cache, prepare_receiver)
}
/// Spawns the background task.
/// The background task initializes the cache from the disk,
/// processes prepare requests, and cleans up outdated cache entries.
/// Noop if the cache is disabled (config is None).
pub fn spawn_background_task(
self: Arc<Self>,
runtime_handle: &tokio::runtime::Handle,
prepare_receiver: BasebackupPrepareReceiver,
tenant_manager: Arc<TenantManager>,
cancel: CancellationToken,
) {
if let Some(config) = self.config.clone() {
let background = BackgroundTask {
c: cache.clone(),
c: self,
config,
tenant_manager,
@@ -108,8 +133,45 @@ impl BasebackupCache {
};
runtime_handle.spawn(background.run(prepare_receiver));
}
}
cache
/// Send a basebackup prepare request to the background task.
/// The basebackup will be prepared asynchronously, it does not block the caller.
/// The request will be skipped if any cache limits are exceeded.
pub fn send_prepare(&self, tenant_shard_id: TenantShardId, timeline_id: TimelineId, lsn: Lsn) {
let req = BasebackupPrepareRequest {
tenant_shard_id,
timeline_id,
lsn,
};
BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE.inc();
let res = self.prepare_sender.try_send(req);
if let Err(e) = res {
BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE.dec();
self.prepare_skip_count.inc();
match e {
TrySendError::Full(_) => {
// Basebackup prepares are pretty rare, normally we should not hit this.
tracing::info!(
tenant_id = %tenant_shard_id.tenant_id,
%timeline_id,
%lsn,
"Basebackup prepare channel is full, skipping the request"
);
}
TrySendError::Closed(_) => {
// Normal during shutdown, not critical.
tracing::info!(
tenant_id = %tenant_shard_id.tenant_id,
%timeline_id,
%lsn,
"Basebackup prepare channel is closed, skipping the request"
);
}
}
}
}
/// Gets a basebackup entry from the cache.
@@ -122,6 +184,10 @@ impl BasebackupCache {
timeline_id: TimelineId,
lsn: Lsn,
) -> Option<tokio::fs::File> {
if !self.is_enabled() {
return None;
}
// Fast path. Check if the entry exists using the in-memory state.
let tti = TenantTimelineId::new(tenant_id, timeline_id);
if self.entries.lock().unwrap().get(&tti).map(|e| e.lsn) != Some(lsn) {
@@ -149,6 +215,10 @@ impl BasebackupCache {
}
}
pub fn is_enabled(&self) -> bool {
self.config.is_some()
}
// Private methods.
fn entry_filename(tenant_id: TenantId, timeline_id: TimelineId, lsn: Lsn) -> String {
@@ -366,6 +436,7 @@ impl BackgroundTask {
loop {
tokio::select! {
Some(req) = prepare_receiver.recv() => {
BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE.dec();
if let Err(err) = self.prepare_basebackup(
req.tenant_shard_id,
req.timeline_id,

View File

@@ -569,8 +569,10 @@ fn start_pageserver(
pageserver::l0_flush::L0FlushGlobalState::new(conf.l0_flush.clone());
// Scan the local 'tenants/' directory and start loading the tenants
let (basebackup_prepare_sender, basebackup_prepare_receiver) =
tokio::sync::mpsc::unbounded_channel();
let (basebackup_cache, basebackup_prepare_receiver) = BasebackupCache::new(
conf.basebackup_cache_dir(),
conf.basebackup_cache_config.clone(),
);
let deletion_queue_client = deletion_queue.new_client();
let background_purges = mgr::BackgroundPurges::default();
@@ -582,7 +584,7 @@ fn start_pageserver(
remote_storage: remote_storage.clone(),
deletion_queue_client,
l0_flush_global_state,
basebackup_prepare_sender,
basebackup_cache: Arc::clone(&basebackup_cache),
feature_resolver: feature_resolver.clone(),
},
shutdown_pageserver.clone(),
@@ -590,10 +592,8 @@ fn start_pageserver(
let tenant_manager = Arc::new(tenant_manager);
BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(tenant_manager.clone(), order))?;
let basebackup_cache = BasebackupCache::spawn(
basebackup_cache.spawn_background_task(
BACKGROUND_RUNTIME.handle(),
conf.basebackup_cache_dir(),
conf.basebackup_cache_config.clone(),
basebackup_prepare_receiver,
Arc::clone(&tenant_manager),
shutdown_pageserver.child_token(),
@@ -806,7 +806,6 @@ fn start_pageserver(
} else {
None
},
basebackup_cache,
);
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for

View File

@@ -37,7 +37,7 @@ async fn main() -> anyhow::Result<()> {
not_modified_since: Lsn(23),
},
batch_key: 42,
message: format!("message {}", msg),
message: format!("message {msg}"),
}));
let Ok(res) = tokio::time::timeout(Duration::from_secs(10), fut).await else {
eprintln!("pipe seems full");

View File

@@ -781,4 +781,21 @@ mod tests {
PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir)
.expect("parse_and_validate");
}
#[test]
fn test_config_posthog_incomplete_config_is_valid() {
let input = r#"
control_plane_api = "http://localhost:6666"
[posthog_config]
server_api_key = "phs_AAA"
private_api_url = "https://us.posthog.com"
public_api_url = "https://us.i.posthog.com"
"#;
let config_toml = toml_edit::de::from_str::<pageserver_api::config::ConfigToml>(input)
.expect("posthogconfig is valid");
let workdir = Utf8PathBuf::from("/nonexistent");
PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir)
.expect("parse_and_validate");
}
}

View File

@@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
use arc_swap::ArcSwap;
use pageserver_api::config::NodeMetadata;
use posthog_client_lite::{
CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
CaptureEvent, FeatureResolverBackgroundLoop, PostHogEvaluationError,
PostHogFlagFilterPropertyValue,
};
use remote_storage::RemoteStorageKind;
@@ -45,16 +45,24 @@ impl FeatureResolver {
) -> anyhow::Result<Self> {
// DO NOT block in this function: make it return as fast as possible to avoid startup delays.
if let Some(posthog_config) = &conf.posthog_config {
let inner = FeatureResolverBackgroundLoop::new(
PostHogClientConfig {
server_api_key: posthog_config.server_api_key.clone(),
client_api_key: posthog_config.client_api_key.clone(),
project_id: posthog_config.project_id.clone(),
private_api_url: posthog_config.private_api_url.clone(),
public_api_url: posthog_config.public_api_url.clone(),
},
shutdown_pageserver,
);
let posthog_client_config = match posthog_config.clone().try_into_posthog_config() {
Ok(config) => config,
Err(e) => {
tracing::warn!(
"invalid posthog config, skipping posthog integration: {}",
e
);
return Ok(FeatureResolver {
inner: None,
internal_properties: None,
force_overrides_for_testing: Arc::new(ArcSwap::new(Arc::new(
HashMap::new(),
))),
});
}
};
let inner =
FeatureResolverBackgroundLoop::new(posthog_client_config, shutdown_pageserver);
let inner = Arc::new(inner);
// The properties shared by all tenants on this pageserver.

View File

@@ -1893,9 +1893,13 @@ async fn update_tenant_config_handler(
let location_conf = LocationConf::attached_single(
new_tenant_conf.clone(),
tenant.get_generation(),
&ShardParameters::default(),
ShardParameters::from(tenant.get_shard_identity()),
);
tenant
.get_shard_identity()
.assert_equal(location_conf.shard); // not strictly necessary since we construct it above
crate::tenant::TenantShard::persist_tenant_config(state.conf, &tenant_shard_id, &location_conf)
.await
.map_err(|e| ApiError::InternalServerError(anyhow::anyhow!(e)))?;
@@ -1937,9 +1941,13 @@ async fn patch_tenant_config_handler(
let location_conf = LocationConf::attached_single(
updated,
tenant.get_generation(),
&ShardParameters::default(),
ShardParameters::from(tenant.get_shard_identity()),
);
tenant
.get_shard_identity()
.assert_equal(location_conf.shard); // not strictly necessary since we construct it above
crate::tenant::TenantShard::persist_tenant_config(state.conf, &tenant_shard_id, &location_conf)
.await
.map_err(|e| ApiError::InternalServerError(anyhow::anyhow!(e)))?;

View File

@@ -4439,6 +4439,14 @@ pub(crate) static BASEBACKUP_CACHE_SIZE: Lazy<UIntGauge> = Lazy::new(|| {
.expect("failed to define a metric")
});
pub(crate) static BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE: Lazy<UIntGauge> = Lazy::new(|| {
register_uint_gauge!(
"pageserver_basebackup_cache_prepare_queue_size",
"Number of requests in the basebackup prepare channel"
)
.expect("failed to define a metric")
});
static PAGESERVER_CONFIG_IGNORED_ITEMS: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_config_ignored_items",

View File

@@ -12,8 +12,9 @@ use std::task::{Context, Poll};
use std::time::{Duration, Instant, SystemTime};
use std::{io, str};
use anyhow::{Context as _, anyhow, bail};
use anyhow::{Context as _, bail};
use bytes::{Buf as _, BufMut as _, BytesMut};
use chrono::Utc;
use futures::future::BoxFuture;
use futures::{FutureExt, Stream};
use itertools::Itertools;
@@ -62,7 +63,6 @@ use utils::{failpoint_support, span_record};
use crate::auth::check_permission;
use crate::basebackup::{self, BasebackupError};
use crate::basebackup_cache::BasebackupCache;
use crate::config::PageServerConf;
use crate::context::{
DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder,
@@ -137,7 +137,6 @@ pub fn spawn(
perf_trace_dispatch: Option<Dispatch>,
tcp_listener: tokio::net::TcpListener,
tls_config: Option<Arc<rustls::ServerConfig>>,
basebackup_cache: Arc<BasebackupCache>,
) -> Listener {
let cancel = CancellationToken::new();
let libpq_ctx = RequestContext::todo_child(
@@ -159,7 +158,6 @@ pub fn spawn(
conf.pg_auth_type,
tls_config,
conf.page_service_pipelining.clone(),
basebackup_cache,
libpq_ctx,
cancel.clone(),
)
@@ -218,7 +216,6 @@ pub async fn libpq_listener_main(
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
pipelining_config: PageServicePipeliningConfig,
basebackup_cache: Arc<BasebackupCache>,
listener_ctx: RequestContext,
listener_cancel: CancellationToken,
) -> Connections {
@@ -262,7 +259,6 @@ pub async fn libpq_listener_main(
auth_type,
tls_config.clone(),
pipelining_config.clone(),
Arc::clone(&basebackup_cache),
connection_ctx,
connections_cancel.child_token(),
gate_guard,
@@ -305,7 +301,6 @@ async fn page_service_conn_main(
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
pipelining_config: PageServicePipeliningConfig,
basebackup_cache: Arc<BasebackupCache>,
connection_ctx: RequestContext,
cancel: CancellationToken,
gate_guard: GateGuard,
@@ -371,7 +366,6 @@ async fn page_service_conn_main(
pipelining_config,
conf.get_vectored_concurrent_io,
perf_span_fields,
basebackup_cache,
connection_ctx,
cancel.clone(),
gate_guard,
@@ -425,8 +419,6 @@ struct PageServerHandler {
pipelining_config: PageServicePipeliningConfig,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
basebackup_cache: Arc<BasebackupCache>,
gate_guard: GateGuard,
}
@@ -912,7 +904,6 @@ impl PageServerHandler {
pipelining_config: PageServicePipeliningConfig,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
perf_span_fields: ConnectionPerfSpanFields,
basebackup_cache: Arc<BasebackupCache>,
connection_ctx: RequestContext,
cancel: CancellationToken,
gate_guard: GateGuard,
@@ -926,7 +917,6 @@ impl PageServerHandler {
cancel,
pipelining_config,
get_vectored_concurrent_io,
basebackup_cache,
gate_guard,
}
}
@@ -2619,20 +2609,9 @@ impl PageServerHandler {
} else {
let mut writer = BufWriter::new(pgb.copyout_writer());
let cached = {
// Basebackup is cached only for this combination of parameters.
if timeline.is_basebackup_cache_enabled()
&& gzip
&& lsn.is_some()
&& prev_lsn.is_none()
{
self.basebackup_cache
.get(tenant_id, timeline_id, lsn.unwrap())
.await
} else {
None
}
};
let cached = timeline
.get_cached_basebackup_if_enabled(lsn, prev_lsn, full_backup, replica, gzip)
.await;
if let Some(mut cached) = cached {
from_cache = true;
@@ -3568,21 +3547,41 @@ impl proto::PageService for GrpcPageServiceHandler {
page_api::BaseBackupCompression::Gzip => Some(async_compression::Level::Fastest),
};
let result = basebackup::send_basebackup_tarball(
&mut simplex_write,
&timeline,
req.lsn,
None,
req.full,
req.replica,
gzip_level,
&ctx,
)
.instrument(span) // propagate request span
.await;
simplex_write.shutdown().await.map_err(|err| {
BasebackupError::Server(anyhow!("simplex shutdown failed: {err}"))
})?;
// Check for a cached basebackup.
let cached = timeline
.get_cached_basebackup_if_enabled(
req.lsn,
None,
req.full,
req.replica,
gzip_level.is_some(),
)
.await;
let result = if let Some(mut cached) = cached {
// If we have a cached basebackup, send it.
tokio::io::copy(&mut cached, &mut simplex_write)
.await
.map(|_| ())
.map_err(|err| BasebackupError::Client(err, "cached,copy"))
} else {
basebackup::send_basebackup_tarball(
&mut simplex_write,
&timeline,
req.lsn,
None,
req.full,
req.replica,
gzip_level,
&ctx,
)
.instrument(span) // propagate request span
.await
};
simplex_write
.shutdown()
.await
.map_err(|err| BasebackupError::Client(err, "simplex_write"))?;
result
});
@@ -3762,6 +3761,36 @@ impl proto::PageService for GrpcPageServiceHandler {
let resp: page_api::GetSlruSegmentResponse = resp.segment;
Ok(tonic::Response::new(resp.into()))
}
#[instrument(skip_all, fields(lsn))]
async fn lease_lsn(
&self,
req: tonic::Request<proto::LeaseLsnRequest>,
) -> Result<tonic::Response<proto::LeaseLsnResponse>, tonic::Status> {
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
// Validate and convert the request, and decorate the span.
let req: page_api::LeaseLsnRequest = req.into_inner().try_into()?;
span_record!(lsn=%req.lsn);
// Attempt to acquire a lease. Return FailedPrecondition if the lease could not be granted.
let lease_length = timeline.get_lsn_lease_length();
let expires = match timeline.renew_lsn_lease(req.lsn, lease_length, &ctx) {
Ok(lease) => lease.valid_until,
Err(err) => return Err(tonic::Status::failed_precondition(format!("{err}"))),
};
// TODO: is this spammy? Move it compute-side?
info!(
"acquired lease for {} until {}",
req.lsn,
chrono::DateTime::<Utc>::from(expires).to_rfc3339()
);
Ok(tonic::Response::new(expires.into()))
}
}
/// gRPC middleware layer that handles observability concerns:

View File

@@ -3015,7 +3015,7 @@ mod tests {
// This shard will get the even blocks
let shard = ShardIdentity::from_params(
ShardNumber(0),
&ShardParameters {
ShardParameters {
count: ShardCount(2),
stripe_size: ShardStripeSize(1),
},

View File

@@ -80,7 +80,7 @@ use self::timeline::uninit::{TimelineCreateGuard, TimelineExclusionError, Uninit
use self::timeline::{
EvictionTaskTenantState, GcCutoffs, TimelineDeleteProgress, TimelineResources, WaitLsnError,
};
use crate::basebackup_cache::BasebackupPrepareSender;
use crate::basebackup_cache::BasebackupCache;
use crate::config::PageServerConf;
use crate::context;
use crate::context::RequestContextBuilder;
@@ -162,7 +162,7 @@ pub struct TenantSharedResources {
pub remote_storage: GenericRemoteStorage,
pub deletion_queue_client: DeletionQueueClient,
pub l0_flush_global_state: L0FlushGlobalState,
pub basebackup_prepare_sender: BasebackupPrepareSender,
pub basebackup_cache: Arc<BasebackupCache>,
pub feature_resolver: FeatureResolver,
}
@@ -331,7 +331,7 @@ pub struct TenantShard {
deletion_queue_client: DeletionQueueClient,
/// A channel to send async requests to prepare a basebackup for the basebackup cache.
basebackup_prepare_sender: BasebackupPrepareSender,
basebackup_cache: Arc<BasebackupCache>,
/// Cached logical sizes updated updated on each [`TenantShard::gather_size_inputs`].
cached_logical_sizes: tokio::sync::Mutex<HashMap<(TimelineId, Lsn), u64>>,
@@ -1363,7 +1363,7 @@ impl TenantShard {
remote_storage,
deletion_queue_client,
l0_flush_global_state,
basebackup_prepare_sender,
basebackup_cache,
feature_resolver,
} = resources;
@@ -1380,7 +1380,7 @@ impl TenantShard {
remote_storage.clone(),
deletion_queue_client,
l0_flush_global_state,
basebackup_prepare_sender,
basebackup_cache,
feature_resolver,
));
@@ -3872,6 +3872,10 @@ impl TenantShard {
&self.tenant_shard_id
}
pub(crate) fn get_shard_identity(&self) -> ShardIdentity {
self.shard_identity
}
pub(crate) fn get_shard_stripe_size(&self) -> ShardStripeSize {
self.shard_identity.stripe_size
}
@@ -4380,7 +4384,7 @@ impl TenantShard {
remote_storage: GenericRemoteStorage,
deletion_queue_client: DeletionQueueClient,
l0_flush_global_state: L0FlushGlobalState,
basebackup_prepare_sender: BasebackupPrepareSender,
basebackup_cache: Arc<BasebackupCache>,
feature_resolver: FeatureResolver,
) -> TenantShard {
assert!(!attached_conf.location.generation.is_none());
@@ -4485,7 +4489,7 @@ impl TenantShard {
ongoing_timeline_detach: std::sync::Mutex::default(),
gc_block: Default::default(),
l0_flush_global_state,
basebackup_prepare_sender,
basebackup_cache,
feature_resolver,
}
}
@@ -4525,6 +4529,10 @@ impl TenantShard {
Ok(toml_edit::de::from_str::<LocationConf>(&config)?)
}
/// Stores a tenant location config to disk.
///
/// NB: make sure to call `ShardIdentity::assert_equal` before persisting a new config, to avoid
/// changes to shard parameters that may result in data corruption.
#[tracing::instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()))]
pub(super) async fn persist_tenant_config(
conf: &'static PageServerConf,
@@ -5414,7 +5422,7 @@ impl TenantShard {
pagestream_throttle_metrics: self.pagestream_throttle_metrics.clone(),
l0_compaction_trigger: self.l0_compaction_trigger.clone(),
l0_flush_global_state: self.l0_flush_global_state.clone(),
basebackup_prepare_sender: self.basebackup_prepare_sender.clone(),
basebackup_cache: self.basebackup_cache.clone(),
feature_resolver: self.feature_resolver.clone(),
}
}
@@ -6000,7 +6008,7 @@ pub(crate) mod harness {
) -> anyhow::Result<Arc<TenantShard>> {
let walredo_mgr = Arc::new(WalRedoManager::from(TestRedoManager));
let (basebackup_requst_sender, _) = tokio::sync::mpsc::unbounded_channel();
let (basebackup_cache, _) = BasebackupCache::new(Utf8PathBuf::new(), None);
let tenant = Arc::new(TenantShard::new(
TenantState::Attaching,
@@ -6008,7 +6016,7 @@ pub(crate) mod harness {
AttachedTenantConf::try_from(LocationConf::attached_single(
self.tenant_conf.clone(),
self.generation,
&ShardParameters::default(),
ShardParameters::default(),
))
.unwrap(),
self.shard_identity,
@@ -6018,7 +6026,7 @@ pub(crate) mod harness {
self.deletion_queue.new_client(),
// TODO: ideally we should run all unit tests with both configs
L0FlushGlobalState::new(L0FlushConfig::default()),
basebackup_requst_sender,
basebackup_cache,
FeatureResolver::new_disabled(),
));
@@ -11429,11 +11437,11 @@ mod tests {
if left != right {
eprintln!("---LEFT---");
for left in left.iter() {
eprintln!("{}", left);
eprintln!("{left}");
}
eprintln!("---RIGHT---");
for right in right.iter() {
eprintln!("{}", right);
eprintln!("{right}");
}
assert_eq!(left, right);
}

View File

@@ -12,6 +12,7 @@
use pageserver_api::models;
use pageserver_api::shard::{ShardCount, ShardIdentity, ShardNumber, ShardStripeSize};
use serde::{Deserialize, Serialize};
use utils::critical;
use utils::generation::Generation;
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq)]
@@ -136,7 +137,7 @@ impl LocationConf {
pub(crate) fn attached_single(
tenant_conf: pageserver_api::models::TenantConfig,
generation: Generation,
shard_params: &models::ShardParameters,
shard_params: models::ShardParameters,
) -> Self {
Self {
mode: LocationMode::Attached(AttachedLocationConfig {
@@ -171,6 +172,16 @@ impl LocationConf {
}
}
// This should never happen.
// TODO: turn this into a proper assertion.
if stripe_size != self.shard.stripe_size {
critical!(
"stripe size mismatch: {} != {}",
self.shard.stripe_size,
stripe_size,
);
}
self.shard.stripe_size = stripe_size;
}

View File

@@ -880,6 +880,9 @@ impl TenantManager {
// phase of writing config and/or waiting for flush, before returning.
match fast_path_taken {
Some(FastPathModified::Attached(tenant)) => {
tenant
.shard_identity
.assert_equal(new_location_config.shard);
TenantShard::persist_tenant_config(
self.conf,
&tenant_shard_id,
@@ -914,7 +917,10 @@ impl TenantManager {
return Ok(Some(tenant));
}
Some(FastPathModified::Secondary(_secondary_tenant)) => {
Some(FastPathModified::Secondary(secondary_tenant)) => {
secondary_tenant
.shard_identity
.assert_equal(new_location_config.shard);
TenantShard::persist_tenant_config(
self.conf,
&tenant_shard_id,
@@ -948,6 +954,10 @@ impl TenantManager {
match slot_guard.get_old_value() {
Some(TenantSlot::Attached(tenant)) => {
tenant
.shard_identity
.assert_equal(new_location_config.shard);
// The case where we keep a Tenant alive was covered above in the special case
// for Attached->Attached transitions in the same generation. By this point,
// if we see an attached tenant we know it will be discarded and should be
@@ -981,9 +991,13 @@ impl TenantManager {
// rather than assuming it to be empty.
spawn_mode = SpawnMode::Eager;
}
Some(TenantSlot::Secondary(state)) => {
Some(TenantSlot::Secondary(secondary_tenant)) => {
secondary_tenant
.shard_identity
.assert_equal(new_location_config.shard);
info!("Shutting down secondary tenant");
state.shutdown().await;
secondary_tenant.shutdown().await;
}
Some(TenantSlot::InProgress(_)) => {
// This should never happen: acquire_slot should error out
@@ -2200,7 +2214,7 @@ impl TenantManager {
selector: ShardSelector,
) -> ShardResolveResult {
let tenants = self.tenants.read().unwrap();
let mut want_shard = None;
let mut want_shard: Option<ShardIndex> = None;
let mut any_in_progress = None;
match &*tenants {
@@ -2225,14 +2239,23 @@ impl TenantManager {
return ShardResolveResult::Found(tenant.clone());
}
ShardSelector::Page(key) => {
// First slot we see for this tenant, calculate the expected shard number
// for the key: we will use this for checking if this and subsequent
// slots contain the key, rather than recalculating the hash each time.
if want_shard.is_none() {
want_shard = Some(tenant.shard_identity.get_shard_number(&key));
// Each time we find an attached slot with a different shard count,
// recompute the expected shard number: during shard splits we might
// have multiple shards with the old shard count.
if want_shard.is_none()
|| want_shard.unwrap().shard_count != tenant.shard_identity.count
{
want_shard = Some(ShardIndex {
shard_number: tenant.shard_identity.get_shard_number(&key),
shard_count: tenant.shard_identity.count,
});
}
if Some(tenant.shard_identity.number) == want_shard {
if Some(ShardIndex {
shard_number: tenant.shard_identity.number,
shard_count: tenant.shard_identity.count,
}) == want_shard
{
return ShardResolveResult::Found(tenant.clone());
}
}
@@ -2891,14 +2914,18 @@ mod tests {
use std::collections::BTreeMap;
use std::sync::Arc;
use camino::Utf8PathBuf;
use storage_broker::BrokerClientChannel;
use tracing::Instrument;
use super::super::harness::TenantHarness;
use super::TenantsMap;
use crate::tenant::{
TenantSharedResources,
mgr::{BackgroundPurges, TenantManager, TenantSlot},
use crate::{
basebackup_cache::BasebackupCache,
tenant::{
TenantSharedResources,
mgr::{BackgroundPurges, TenantManager, TenantSlot},
},
};
#[tokio::test(start_paused = true)]
@@ -2924,9 +2951,7 @@ mod tests {
// Invoke remove_tenant_from_memory with a cleanup hook that blocks until we manually
// permit it to proceed: that will stick the tenant in InProgress
let (basebackup_prepare_sender, _) = tokio::sync::mpsc::unbounded_channel::<
crate::basebackup_cache::BasebackupPrepareRequest,
>();
let (basebackup_cache, _) = BasebackupCache::new(Utf8PathBuf::new(), None);
let tenant_manager = TenantManager {
tenants: std::sync::RwLock::new(TenantsMap::Open(tenants)),
@@ -2940,7 +2965,7 @@ mod tests {
l0_flush_global_state: crate::l0_flush::L0FlushGlobalState::new(
h.conf.l0_flush.clone(),
),
basebackup_prepare_sender,
basebackup_cache,
feature_resolver: crate::feature_resolver::FeatureResolver::new_disabled(),
},
cancel: tokio_util::sync::CancellationToken::new(),

View File

@@ -101,7 +101,7 @@ pub(crate) struct SecondaryTenant {
// Secondary mode does not need the full shard identity or the pageserver_api::models::TenantConfig. However,
// storing these enables us to report our full LocationConf, enabling convenient reconciliation
// by the control plane (see [`Self::get_location_conf`])
shard_identity: ShardIdentity,
pub(crate) shard_identity: ShardIdentity,
tenant_conf: std::sync::Mutex<pageserver_api::models::TenantConfig>,
// Internal state used by the Downloader.

View File

@@ -55,11 +55,11 @@ pub struct BatchLayerWriter {
}
impl BatchLayerWriter {
pub async fn new(conf: &'static PageServerConf) -> anyhow::Result<Self> {
Ok(Self {
pub fn new(conf: &'static PageServerConf) -> Self {
Self {
generated_layer_writers: Vec::new(),
conf,
})
}
}
pub fn add_unfinished_image_writer(
@@ -182,7 +182,7 @@ impl BatchLayerWriter {
/// An image writer that takes images and produces multiple image layers.
#[must_use]
pub struct SplitImageLayerWriter<'a> {
inner: ImageLayerWriter,
inner: Option<ImageLayerWriter>,
target_layer_size: u64,
lsn: Lsn,
conf: &'static PageServerConf,
@@ -196,7 +196,7 @@ pub struct SplitImageLayerWriter<'a> {
impl<'a> SplitImageLayerWriter<'a> {
#[allow(clippy::too_many_arguments)]
pub async fn new(
pub fn new(
conf: &'static PageServerConf,
timeline_id: TimelineId,
tenant_shard_id: TenantShardId,
@@ -205,30 +205,19 @@ impl<'a> SplitImageLayerWriter<'a> {
target_layer_size: u64,
gate: &'a utils::sync::gate::Gate,
cancel: CancellationToken,
ctx: &RequestContext,
) -> anyhow::Result<Self> {
Ok(Self {
) -> Self {
Self {
target_layer_size,
inner: ImageLayerWriter::new(
conf,
timeline_id,
tenant_shard_id,
&(start_key..Key::MAX),
lsn,
gate,
cancel.clone(),
ctx,
)
.await?,
inner: None,
conf,
timeline_id,
tenant_shard_id,
batches: BatchLayerWriter::new(conf).await?,
batches: BatchLayerWriter::new(conf),
lsn,
start_key,
gate,
cancel,
})
}
}
pub async fn put_image(
@@ -237,12 +226,31 @@ impl<'a> SplitImageLayerWriter<'a> {
img: Bytes,
ctx: &RequestContext,
) -> Result<(), PutError> {
if self.inner.is_none() {
self.inner = Some(
ImageLayerWriter::new(
self.conf,
self.timeline_id,
self.tenant_shard_id,
&(self.start_key..Key::MAX),
self.lsn,
self.gate,
self.cancel.clone(),
ctx,
)
.await
.map_err(PutError::Other)?,
);
}
let inner = self.inner.as_mut().unwrap();
// The current estimation is an upper bound of the space that the key/image could take
// because we did not consider compression in this estimation. The resulting image layer
// could be smaller than the target size.
let addition_size_estimation = KEY_SIZE as u64 + img.len() as u64;
if self.inner.num_keys() >= 1
&& self.inner.estimated_size() + addition_size_estimation >= self.target_layer_size
if inner.num_keys() >= 1
&& inner.estimated_size() + addition_size_estimation >= self.target_layer_size
{
let next_image_writer = ImageLayerWriter::new(
self.conf,
@@ -256,7 +264,7 @@ impl<'a> SplitImageLayerWriter<'a> {
)
.await
.map_err(PutError::Other)?;
let prev_image_writer = std::mem::replace(&mut self.inner, next_image_writer);
let prev_image_writer = std::mem::replace(inner, next_image_writer);
self.batches.add_unfinished_image_writer(
prev_image_writer,
self.start_key..key,
@@ -264,7 +272,7 @@ impl<'a> SplitImageLayerWriter<'a> {
);
self.start_key = key;
}
self.inner.put_image(key, img, ctx).await
inner.put_image(key, img, ctx).await
}
pub(crate) async fn finish_with_discard_fn<D, F>(
@@ -281,8 +289,10 @@ impl<'a> SplitImageLayerWriter<'a> {
let Self {
mut batches, inner, ..
} = self;
if inner.num_keys() != 0 {
batches.add_unfinished_image_writer(inner, self.start_key..end_key, self.lsn);
if let Some(inner) = inner {
if inner.num_keys() != 0 {
batches.add_unfinished_image_writer(inner, self.start_key..end_key, self.lsn);
}
}
batches.finish_with_discard_fn(tline, ctx, discard_fn).await
}
@@ -319,7 +329,7 @@ pub struct SplitDeltaLayerWriter<'a> {
}
impl<'a> SplitDeltaLayerWriter<'a> {
pub async fn new(
pub fn new(
conf: &'static PageServerConf,
timeline_id: TimelineId,
tenant_shard_id: TenantShardId,
@@ -327,8 +337,8 @@ impl<'a> SplitDeltaLayerWriter<'a> {
target_layer_size: u64,
gate: &'a utils::sync::gate::Gate,
cancel: CancellationToken,
) -> anyhow::Result<Self> {
Ok(Self {
) -> Self {
Self {
target_layer_size,
inner: None,
conf,
@@ -336,10 +346,10 @@ impl<'a> SplitDeltaLayerWriter<'a> {
tenant_shard_id,
lsn_range,
last_key_written: Key::MIN,
batches: BatchLayerWriter::new(conf).await?,
batches: BatchLayerWriter::new(conf),
gate,
cancel,
})
}
}
pub async fn put_value(
@@ -497,10 +507,7 @@ mod tests {
4 * 1024 * 1024,
&tline.gate,
tline.cancel.clone(),
&ctx,
)
.await
.unwrap();
);
let mut delta_writer = SplitDeltaLayerWriter::new(
tenant.conf,
@@ -510,9 +517,7 @@ mod tests {
4 * 1024 * 1024,
&tline.gate,
tline.cancel.clone(),
)
.await
.unwrap();
);
image_writer
.put_image(get_key(0), get_img(0), &ctx)
@@ -578,10 +583,7 @@ mod tests {
4 * 1024 * 1024,
&tline.gate,
tline.cancel.clone(),
&ctx,
)
.await
.unwrap();
);
let mut delta_writer = SplitDeltaLayerWriter::new(
tenant.conf,
tline.timeline_id,
@@ -590,9 +592,7 @@ mod tests {
4 * 1024 * 1024,
&tline.gate,
tline.cancel.clone(),
)
.await
.unwrap();
);
const N: usize = 2000;
for i in 0..N {
let i = i as u32;
@@ -679,10 +679,7 @@ mod tests {
4 * 1024,
&tline.gate,
tline.cancel.clone(),
&ctx,
)
.await
.unwrap();
);
let mut delta_writer = SplitDeltaLayerWriter::new(
tenant.conf,
@@ -692,9 +689,7 @@ mod tests {
4 * 1024,
&tline.gate,
tline.cancel.clone(),
)
.await
.unwrap();
);
image_writer
.put_image(get_key(0), get_img(0), &ctx)
@@ -770,9 +765,7 @@ mod tests {
4 * 1024 * 1024,
&tline.gate,
tline.cancel.clone(),
)
.await
.unwrap();
);
for i in 0..N {
let i = i as u32;

View File

@@ -17,14 +17,17 @@ use tracing::*;
use utils::backoff::exponential_backoff_duration;
use utils::completion::Barrier;
use utils::pausable_failpoint;
use utils::sync::gate::GateError;
use crate::context::{DownloadBehavior, RequestContext};
use crate::metrics::{self, BackgroundLoopSemaphoreMetricsRecorder, TENANT_TASK_EVENTS};
use crate::task_mgr::{self, BACKGROUND_RUNTIME, TOKIO_WORKER_THREADS, TaskKind};
use crate::tenant::blob_io::WriteBlobError;
use crate::tenant::throttle::Stats;
use crate::tenant::timeline::CompactionError;
use crate::tenant::timeline::compaction::CompactionOutcome;
use crate::tenant::{TenantShard, TenantState};
use crate::virtual_file::owned_buffers_io::write::FlushTaskError;
/// Semaphore limiting concurrent background tasks (across all tenants).
///
@@ -313,7 +316,20 @@ pub(crate) fn log_compaction_error(
let timeline = root_cause
.downcast_ref::<PageReconstructError>()
.is_some_and(|e| e.is_stopping());
let is_stopping = upload_queue || timeline;
let buffered_writer_flush_task_canelled = root_cause
.downcast_ref::<FlushTaskError>()
.is_some_and(|e| e.is_cancel());
let write_blob_cancelled = root_cause
.downcast_ref::<WriteBlobError>()
.is_some_and(|e| e.is_cancel());
let gate_closed = root_cause
.downcast_ref::<GateError>()
.is_some_and(|e| e.is_cancel());
let is_stopping = upload_queue
|| timeline
|| buffered_writer_flush_task_canelled
|| write_blob_cancelled
|| gate_closed;
if is_stopping {
Level::INFO

View File

@@ -95,12 +95,12 @@ use super::storage_layer::{LayerFringe, LayerVisibilityHint, ReadableLayer};
use super::tasks::log_compaction_error;
use super::upload_queue::NotInitialized;
use super::{
AttachedTenantConf, BasebackupPrepareSender, GcError, HeatMapTimeline, MaybeOffloaded,
AttachedTenantConf, GcError, HeatMapTimeline, MaybeOffloaded,
debug_assert_current_span_has_tenant_and_timeline_id,
};
use crate::PERF_TRACE_TARGET;
use crate::aux_file::AuxFileSizeEstimator;
use crate::basebackup_cache::BasebackupPrepareRequest;
use crate::basebackup_cache::BasebackupCache;
use crate::config::PageServerConf;
use crate::context::{
DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder,
@@ -201,7 +201,7 @@ pub struct TimelineResources {
pub pagestream_throttle_metrics: Arc<crate::metrics::tenant_throttling::Pagestream>,
pub l0_compaction_trigger: Arc<Notify>,
pub l0_flush_global_state: l0_flush::L0FlushGlobalState,
pub basebackup_prepare_sender: BasebackupPrepareSender,
pub basebackup_cache: Arc<BasebackupCache>,
pub feature_resolver: FeatureResolver,
}
@@ -448,7 +448,7 @@ pub struct Timeline {
wait_lsn_log_slow: tokio::sync::Semaphore,
/// A channel to send async requests to prepare a basebackup for the basebackup cache.
basebackup_prepare_sender: BasebackupPrepareSender,
basebackup_cache: Arc<BasebackupCache>,
feature_resolver: FeatureResolver,
}
@@ -763,7 +763,7 @@ pub(crate) enum CreateImageLayersError {
PageReconstructError(#[source] PageReconstructError),
#[error(transparent)]
Other(#[from] anyhow::Error),
Other(anyhow::Error),
}
impl From<layer_manager::Shutdown> for CreateImageLayersError {
@@ -2500,6 +2500,37 @@ impl Timeline {
.unwrap_or(self.conf.default_tenant_conf.basebackup_cache_enabled)
}
/// Try to get a basebackup from the on-disk cache.
pub(crate) async fn get_cached_basebackup(&self, lsn: Lsn) -> Option<tokio::fs::File> {
self.basebackup_cache
.get(self.tenant_shard_id.tenant_id, self.timeline_id, lsn)
.await
}
/// Convenience method to attempt fetching a basebackup for the timeline if enabled and safe for
/// the given request parameters.
///
/// TODO: consider moving this onto GrpcPageServiceHandler once the libpq handler is gone.
pub async fn get_cached_basebackup_if_enabled(
&self,
lsn: Option<Lsn>,
prev_lsn: Option<Lsn>,
full: bool,
replica: bool,
gzip: bool,
) -> Option<tokio::fs::File> {
if !self.is_basebackup_cache_enabled() || !self.basebackup_cache.is_enabled() {
return None;
}
// We have to know which LSN to fetch the basebackup for.
let lsn = lsn?;
// We only cache gzipped, non-full basebackups for primary computes with automatic prev_lsn.
if prev_lsn.is_some() || full || replica || !gzip {
return None;
}
self.get_cached_basebackup(lsn).await
}
/// Prepare basebackup for the given LSN and store it in the basebackup cache.
/// The method is asynchronous and returns immediately.
/// The actual basebackup preparation is performed in the background
@@ -2521,17 +2552,8 @@ impl Timeline {
return;
}
let res = self
.basebackup_prepare_sender
.send(BasebackupPrepareRequest {
tenant_shard_id: self.tenant_shard_id,
timeline_id: self.timeline_id,
lsn,
});
if let Err(e) = res {
// May happen during shutdown, it's not critical.
info!("Failed to send shutdown checkpoint: {e:#}");
}
self.basebackup_cache
.send_prepare(self.tenant_shard_id, self.timeline_id, lsn);
}
}
@@ -3088,7 +3110,7 @@ impl Timeline {
wait_lsn_log_slow: tokio::sync::Semaphore::new(1),
basebackup_prepare_sender: resources.basebackup_prepare_sender,
basebackup_cache: resources.basebackup_cache,
feature_resolver: resources.feature_resolver,
};
@@ -4658,6 +4680,16 @@ impl Timeline {
mut layer_flush_start_rx: tokio::sync::watch::Receiver<(u64, Lsn)>,
ctx: &RequestContext,
) {
// Always notify waiters about the flush loop exiting since the loop might stop
// when the timeline hasn't been cancelled.
let scopeguard_rx = layer_flush_start_rx.clone();
scopeguard::defer! {
let (flush_counter, _) = *scopeguard_rx.borrow();
let _ = self
.layer_flush_done_tx
.send_replace((flush_counter, Err(FlushLayerError::Cancelled)));
}
// Subscribe to L0 delta layer updates, for compaction backpressure.
let mut watch_l0 = match self
.layers
@@ -4687,9 +4719,6 @@ impl Timeline {
let result = loop {
if self.cancel.is_cancelled() {
info!("dropping out of flush loop for timeline shutdown");
// Note: we do not bother transmitting into [`layer_flush_done_tx`], because
// anyone waiting on that will respect self.cancel as well: they will stop
// waiting at the same time we as drop out of this loop.
return;
}
@@ -5561,7 +5590,7 @@ impl Timeline {
self.should_check_if_image_layers_required(lsn)
};
let mut batch_image_writer = BatchLayerWriter::new(self.conf).await?;
let mut batch_image_writer = BatchLayerWriter::new(self.conf);
let mut all_generated = true;
@@ -5665,7 +5694,8 @@ impl Timeline {
self.cancel.clone(),
ctx,
)
.await?;
.await
.map_err(CreateImageLayersError::Other)?;
fail_point!("image-layer-writer-fail-before-finish", |_| {
Err(CreateImageLayersError::Other(anyhow::anyhow!(
@@ -5760,7 +5790,10 @@ impl Timeline {
}
}
let image_layers = batch_image_writer.finish(self, ctx).await?;
let image_layers = batch_image_writer
.finish(self, ctx)
.await
.map_err(CreateImageLayersError::Other)?;
let mut guard = self.layers.write(LayerManagerLockHolder::Compaction).await;

View File

@@ -3503,22 +3503,16 @@ impl Timeline {
// Only create image layers when there is no ancestor branches. TODO: create covering image layer
// when some condition meet.
let mut image_layer_writer = if !has_data_below {
Some(
SplitImageLayerWriter::new(
self.conf,
self.timeline_id,
self.tenant_shard_id,
job_desc.compaction_key_range.start,
lowest_retain_lsn,
self.get_compaction_target_size(),
&self.gate,
self.cancel.clone(),
ctx,
)
.await
.context("failed to create image layer writer")
.map_err(CompactionError::Other)?,
)
Some(SplitImageLayerWriter::new(
self.conf,
self.timeline_id,
self.tenant_shard_id,
job_desc.compaction_key_range.start,
lowest_retain_lsn,
self.get_compaction_target_size(),
&self.gate,
self.cancel.clone(),
))
} else {
None
};
@@ -3531,10 +3525,7 @@ impl Timeline {
self.get_compaction_target_size(),
&self.gate,
self.cancel.clone(),
)
.await
.context("failed to create delta layer writer")
.map_err(CompactionError::Other)?;
);
#[derive(Default)]
struct RewritingLayers {
@@ -4330,7 +4321,8 @@ impl TimelineAdaptor {
self.timeline.cancel.clone(),
ctx,
)
.await?;
.await
.map_err(CreateImageLayersError::Other)?;
fail_point!("image-layer-writer-fail-before-finish", |_| {
Err(CreateImageLayersError::Other(anyhow::anyhow!(
@@ -4339,7 +4331,10 @@ impl TimelineAdaptor {
});
let keyspace = KeySpace {
ranges: self.get_keyspace(key_range, lsn, ctx).await?,
ranges: self
.get_keyspace(key_range, lsn, ctx)
.await
.map_err(CreateImageLayersError::Other)?,
};
// TODO set proper (stateful) start. The create_image_layer_for_rel_blocks function mostly
let outcome = self
@@ -4358,9 +4353,13 @@ impl TimelineAdaptor {
unfinished_image_layer,
} = outcome
{
let (desc, path) = unfinished_image_layer.finish(ctx).await?;
let (desc, path) = unfinished_image_layer
.finish(ctx)
.await
.map_err(CreateImageLayersError::Other)?;
let image_layer =
Layer::finish_creating(self.timeline.conf, &self.timeline, desc, &path)?;
Layer::finish_creating(self.timeline.conf, &self.timeline, desc, &path)
.map_err(CreateImageLayersError::Other)?;
self.new_images.push(image_layer);
}

View File

@@ -241,8 +241,17 @@ impl DeleteTimelineFlow {
{
Ok(r) => r,
Err(DownloadError::NotFound) => {
// Deletion is already complete
// Deletion is already complete.
// As we came here, we will need to remove the timeline from the tenant though.
tracing::info!("Timeline already deleted in remote storage");
if let TimelineOrOffloaded::Offloaded(_) = &timeline {
// We only supoprt this for offloaded timelines, as we don't know which state non-offloaded timelines are in.
tracing::info!(
"Timeline with gone index part is offloaded timeline. Removing from tenant."
);
remove_maybe_offloaded_timeline_from_tenant(tenant, &timeline, &guard)
.await?;
}
return Ok(());
}
Err(e) => {

View File

@@ -885,7 +885,7 @@ async fn remote_copy(
}
tracing::info!("Deleting orphan layer file to make way for hard linking");
// Delete orphan layer file and try again, to ensure this layer has a well understood source
std::fs::remove_file(adopted_path)
std::fs::remove_file(&adoptee_path)
.map_err(|e| Error::launder(e.into(), Error::Prepare))?;
std::fs::hard_link(adopted_path, &adoptee_path)
.map_err(|e| Error::launder(e.into(), Error::Prepare))?;

View File

@@ -887,7 +887,7 @@ mod tests {
.expect("we still have it");
}
fn make_relation_key_for_shard(shard: ShardNumber, params: &ShardParameters) -> Key {
fn make_relation_key_for_shard(shard: ShardNumber, params: ShardParameters) -> Key {
rel_block_to_key(
RelTag {
spcnode: 1663,
@@ -917,14 +917,14 @@ mod tests {
let child0 = Arc::new_cyclic(|myself| StubTimeline {
gate: Default::default(),
id: timeline_id,
shard: ShardIdentity::from_params(ShardNumber(0), &child_params),
shard: ShardIdentity::from_params(ShardNumber(0), child_params),
per_timeline_state: PerTimelineState::default(),
myself: myself.clone(),
});
let child1 = Arc::new_cyclic(|myself| StubTimeline {
gate: Default::default(),
id: timeline_id,
shard: ShardIdentity::from_params(ShardNumber(1), &child_params),
shard: ShardIdentity::from_params(ShardNumber(1), child_params),
per_timeline_state: PerTimelineState::default(),
myself: myself.clone(),
});
@@ -937,7 +937,7 @@ mod tests {
let handle = cache
.get(
timeline_id,
ShardSelector::Page(make_relation_key_for_shard(ShardNumber(i), &child_params)),
ShardSelector::Page(make_relation_key_for_shard(ShardNumber(i), child_params)),
&StubManager {
shards: vec![parent.clone()],
},
@@ -961,7 +961,7 @@ mod tests {
let handle = cache
.get(
timeline_id,
ShardSelector::Page(make_relation_key_for_shard(ShardNumber(i), &child_params)),
ShardSelector::Page(make_relation_key_for_shard(ShardNumber(i), child_params)),
&StubManager {
shards: vec![], // doesn't matter what's in here, the cache is fully loaded
},
@@ -978,7 +978,7 @@ mod tests {
let parent_handle = cache
.get(
timeline_id,
ShardSelector::Page(make_relation_key_for_shard(ShardNumber(0), &child_params)),
ShardSelector::Page(make_relation_key_for_shard(ShardNumber(0), child_params)),
&StubManager {
shards: vec![parent.clone()],
},
@@ -995,7 +995,7 @@ mod tests {
let handle = cache
.get(
timeline_id,
ShardSelector::Page(make_relation_key_for_shard(ShardNumber(i), &child_params)),
ShardSelector::Page(make_relation_key_for_shard(ShardNumber(i), child_params)),
&StubManager {
shards: vec![child0.clone(), child1.clone()], // <====== this changed compared to previous loop
},

121
postgres.mk Normal file
View File

@@ -0,0 +1,121 @@
# Sub-makefile for compiling PostgreSQL as part of Neon. This is
# included from the main Makefile, and is not meant to be called
# directly.
#
# CI workflows and Dockerfiles can take advantage of the following
# properties for caching:
#
# - Compiling the targets in this file only builds the PostgreSQL sources
# under the vendor/ subdirectory, nothing else from the repository.
# - All outputs go to POSTGRES_INSTALL_DIR (by default 'pg_install',
# see parent Makefile)
# - intermediate build artifacts go to BUILD_DIR
#
#
# Variables passed from the parent Makefile that control what gets
# installed and where:
# - POSTGRES_VERSIONS
# - POSTGRES_INSTALL_DIR
# - BUILD_DIR
#
# Variables passed from the parent Makefile that affect the build
# process and the resulting binaries:
# - PG_CONFIGURE_OPTS
# - PG_CFLAGS
# - PG_LDFLAGS
# - EXTRA_PATH_OVERRIDES
###
### Main targets
###
### These are called from the main Makefile, and can also be called
### directly from command line
# Compile and install a specific PostgreSQL version
postgres-install-%: postgres-configure-% \
postgres-headers-install-% # to prevent `make install` conflicts with neon's `postgres-headers`
# Install the PostgreSQL header files into $(POSTGRES_INSTALL_DIR)/<version>/include
#
# This is implicitly part of the 'postgres-install-%' target, but this can be handy
# if you want to install just the headers without building PostgreSQL, e.g. for building
# extensions.
postgres-headers-install-%: postgres-configure-%
+@echo "Installing PostgreSQL $* headers"
$(MAKE) -C $(BUILD_DIR)/$*/src/include MAKELEVEL=0 install
# Run Postgres regression tests
postgres-check-%: postgres-install-%
$(MAKE) -C $(BUILD_DIR)/$* MAKELEVEL=0 check
###
### Shorthands for the main targets, for convenience
###
# Same as the above main targets, but for all supported PostgreSQL versions
# For example, 'make postgres-install' is equivalent to
# 'make postgres-install-v14 postgres-install-v15 postgres-install-v16 postgres-install-v17'
all_version_targets=postgres-install postgres-headers-install postgres-check
.PHONY: $(all_version_targets)
$(all_version_targets): postgres-%: $(foreach pg_version,$(POSTGRES_VERSIONS),postgres-%-$(pg_version))
.PHONY: postgres
postgres: postgres-install
.PHONY: postgres-headers
postgres-headers: postgres-headers-install
# 'postgres-v17' is an alias for 'postgres-install-v17' etc.
$(foreach pg_version,$(POSTGRES_VERSIONS),postgres-$(pg_version)): postgres-%: postgres-install-%
###
### Intermediate targets
###
### These are not intended to be called directly, but are dependencies for the
### main targets.
# Run 'configure'
$(BUILD_DIR)/%/config.status:
mkdir -p $(BUILD_DIR)
test -e $(BUILD_DIR)/CACHEDIR.TAG || echo "$(CACHEDIR_TAG_CONTENTS)" > $(BUILD_DIR)/CACHEDIR.TAG
+@echo "Configuring Postgres $* build"
@test -s $(ROOT_PROJECT_DIR)/vendor/postgres-$*/configure || { \
echo "\nPostgres submodule not found in $(ROOT_PROJECT_DIR)/vendor/postgres-$*/, execute "; \
echo "'git submodule update --init --recursive --depth 2 --progress .' in project root.\n"; \
exit 1; }
mkdir -p $(BUILD_DIR)/$*
VERSION=$*; \
EXTRA_VERSION=$$(cd $(ROOT_PROJECT_DIR)/vendor/postgres-$$VERSION && git rev-parse HEAD); \
(cd $(BUILD_DIR)/$$VERSION && \
env PATH="$(EXTRA_PATH_OVERRIDES):$$PATH" $(ROOT_PROJECT_DIR)/vendor/postgres-$$VERSION/configure \
CFLAGS='$(PG_CFLAGS)' LDFLAGS='$(PG_LDFLAGS)' \
$(PG_CONFIGURE_OPTS) --with-extra-version=" ($$EXTRA_VERSION)" \
--prefix=$(abspath $(POSTGRES_INSTALL_DIR))/$$VERSION > configure.log)
# nicer alias to run 'configure'.
#
# This tries to accomplish this rule:
#
# postgres-configure-%: $(BUILD_DIR)/%/config.status
#
# XXX: I'm not sure why the above rule doesn't work directly. But this accomplishses
# the same thing
$(foreach pg_version,$(POSTGRES_VERSIONS),postgres-configure-$(pg_version)): postgres-configure-%: FORCE $(BUILD_DIR)/%/config.status
# Compile and install PostgreSQL (and a few contrib modules used in tests)
postgres-install-%: postgres-configure-% \
postgres-headers-install-% # to prevent `make install` conflicts with neon's `postgres-headers-install`
+@echo "Compiling PostgreSQL $*"
$(MAKE) -C $(BUILD_DIR)/$* MAKELEVEL=0 install
$(MAKE) -C $(BUILD_DIR)/$*/contrib/pg_prewarm install
$(MAKE) -C $(BUILD_DIR)/$*/contrib/pg_buffercache install
$(MAKE) -C $(BUILD_DIR)/$*/contrib/pg_visibility install
$(MAKE) -C $(BUILD_DIR)/$*/contrib/pageinspect install
$(MAKE) -C $(BUILD_DIR)/$*/contrib/pg_trgm install
$(MAKE) -C $(BUILD_DIR)/$*/contrib/amcheck install
$(MAKE) -C $(BUILD_DIR)/$*/contrib/test_decoding install
.PHONY: FORCE
FORCE:

View File

@@ -288,7 +288,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
},
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
handshake_timeout: Duration::from_secs(10),
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,

View File

@@ -26,9 +26,10 @@ use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::pglb::TlsRequired;
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute};
use crate::stream::{PqStream, Stream};
use crate::util::run_until_cancelled;
@@ -236,7 +237,6 @@ pub(super) async fn task_main(
extra: None,
},
crate::metrics::Protocol::SniRouter,
"sni",
);
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
}

View File

@@ -154,12 +154,6 @@ struct ProxyCliArgs {
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// cache for `wake_compute` api method (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
wake_compute_cache: String,
@@ -186,40 +180,31 @@ struct ProxyCliArgs {
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Cancellation channel size (max queue size for redis kv client)
#[clap(long, default_value_t = 1024)]
cancellation_ch_size: usize,
/// Cancellation ops batch size for redis
#[clap(long, default_value_t = 8)]
cancellation_batch_size: usize,
/// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
allowed_ips_cache: String,
/// cache for `role_secret` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
role_secret_cache: String,
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
#[clap(long)]
redis_notifications: Option<String>,
/// what from the available authentications type to use for the regional redis we have. Supported are "irsa" and "plain".
/// redis url for plain authentication
#[clap(long, alias("redis-notifications"))]
redis_plain: Option<String>,
/// what from the available authentications type to use for redis. Supported are "irsa" and "plain".
#[clap(long, default_value = "irsa")]
redis_auth_type: String,
/// redis host for streaming connections (might be different from the notifications host)
/// redis host for irsa authentication
#[clap(long)]
redis_host: Option<String>,
/// redis port for streaming connections (might be different from the notifications host)
/// redis port for irsa authentication
#[clap(long)]
redis_port: Option<u16>,
/// redis cluster name, used in aws elasticache
/// redis cluster name for irsa authentication
#[clap(long)]
redis_cluster_name: Option<String>,
/// redis user_id, used in aws elasticache
/// redis user_id for irsa authentication
#[clap(long)]
redis_user_id: Option<String>,
/// aws region to retrieve credentials
/// aws region for irsa authentication
#[clap(long, default_value_t = String::new())]
aws_region: String,
/// cache for `project_info` (use `size=0` to disable)
@@ -231,6 +216,12 @@ struct ProxyCliArgs {
#[clap(flatten)]
parquet_upload: ParquetUploadArgs,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// interval for backup metric collection
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
metric_backup_collection_interval: std::time::Duration,
@@ -243,6 +234,7 @@ struct ProxyCliArgs {
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
#[clap(long, default_value = "4194304")]
metric_backup_collection_chunk_size: usize,
/// Whether to retry the connection to the compute node
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
connect_to_compute_retry: String,
@@ -370,7 +362,7 @@ pub async fn run() -> anyhow::Result<()> {
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
}
info!("Using region: {}", args.aws_region);
let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?;
let redis_client = configure_redis(&args).await?;
// Check that we can bind to address before further initialization
info!("Starting http on {}", args.http);
@@ -425,13 +417,6 @@ pub async fn run() -> anyhow::Result<()> {
let cancellation_token = CancellationToken::new();
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
RateBucketInfo::validate(redis_rps_limit)?;
let redis_kv_client = regional_redis_client
.as_ref()
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
let cancellation_handler = Arc::new(CancellationHandler::new(&config.connect_to_compute));
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
@@ -446,7 +431,7 @@ pub async fn run() -> anyhow::Result<()> {
match auth_backend {
Either::Left(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(crate::proxy::task_main(
client_tasks.spawn(crate::pglb::task_main(
config,
auth_backend,
proxy_listener,
@@ -528,6 +513,7 @@ pub async fn run() -> anyhow::Result<()> {
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
args.region,
));
// maintenance tasks. these never return unless there's an error
@@ -561,32 +547,17 @@ pub async fn run() -> anyhow::Result<()> {
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
}
if let Some(client) = redis_client {
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
if let Some(mut redis_kv_client) = redis_kv_client {
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
@@ -611,14 +582,12 @@ pub async fn run() -> anyhow::Result<()> {
}
}
}
}
if let Some(regional_redis_client) = regional_redis_client {
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await }
async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span),
);
}
@@ -765,7 +734,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
authentication_config,
proxy_protocol_v2: args.proxy_protocol_v2,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
@@ -942,21 +910,18 @@ fn build_auth_backend(
async fn configure_redis(
args: &ProxyCliArgs,
) -> anyhow::Result<(
Option<ConnectionWithCredentialsProvider>,
Option<ConnectionWithCredentialsProvider>,
)> {
) -> anyhow::Result<Option<ConnectionWithCredentialsProvider>> {
// TODO: untangle the config args
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
("plain", redis_url) => match redis_url {
let redis_client = match &*args.redis_auth_type {
"plain" => match &args.redis_plain {
None => {
bail!("plain auth requires redis_notifications to be set");
bail!("plain auth requires redis_plain to be set");
}
Some(url) => {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
}
},
("irsa", _) => match (&args.redis_host, args.redis_port) {
"irsa" => match (&args.redis_host, args.redis_port) {
(Some(host), Some(port)) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host.clone(),
@@ -980,18 +945,12 @@ async fn configure_redis(
bail!("redis-host and redis-port must be specified together");
}
},
_ => {
bail!("unknown auth type given");
auth_type => {
bail!("unknown auth type {auth_type:?} given")
}
};
let redis_notifications_client = if let Some(url) = &args.redis_notifications {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
} else {
regional_redis_client.clone()
};
Ok((regional_redis_client, redis_notifications_client))
Ok(redis_client)
}

View File

@@ -30,7 +30,7 @@ use super::{Cache, timed_lru};
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
/// It's useful in case we know for sure that the entry is no longer correct.
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
/// See [`timed_lru::Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
/// or by a successful lookup (i.e. the entry hasn't expired yet).
@@ -54,7 +54,7 @@ pub(crate) struct TimedLru<K, V> {
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
type Key = K;
type Value = V;
type LookupInfo<Key> = LookupInfo<Key>;
type LookupInfo<Key> = Key;
fn invalidate(&self, info: &Self::LookupInfo<K>) {
self.invalidate_raw(info);
@@ -87,30 +87,24 @@ impl<K: Hash + Eq, V> TimedLru<K, V> {
/// Drop an entry from the cache if it's outdated.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn invalidate_raw(&self, info: &LookupInfo<K>) {
let now = Instant::now();
fn invalidate_raw(&self, key: &K) {
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
let entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return,
RawEntryMut::Occupied(x) => x,
RawEntryMut::Occupied(x) => x.remove(),
};
// Remove the entry if it was created prior to lookup timestamp.
let entry = raw_entry.get();
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
let should_remove = created_at <= info.created_at || expires_at <= now;
if should_remove {
raw_entry.remove();
}
drop(cache); // drop lock before logging
let Entry {
created_at,
expires_at,
..
} = entry;
debug!(
created_at = format_args!("{created_at:?}"),
expires_at = format_args!("{expires_at:?}"),
entry_removed = should_remove,
?created_at,
?expires_at,
"processed a cache entry invalidation event"
);
}
@@ -215,10 +209,10 @@ impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
}
pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
let (created_at, old) = self.insert_raw(key.clone(), value);
let (_, old) = self.insert_raw(key.clone(), value);
let cached = Cached {
token: Some((self, LookupInfo { created_at, key })),
token: Some((self, key)),
value: (),
};
@@ -255,28 +249,9 @@ impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
K: Borrow<Q> + Clone,
Q: Hash + Eq + ?Sized,
{
self.get_raw(key, |key, entry| {
let info = LookupInfo {
created_at: entry.created_at,
key: key.clone(),
};
Cached {
token: Some((self, info)),
value: entry.value.clone(),
}
self.get_raw(key, |key, entry| Cached {
token: Some((self, key.clone())),
value: entry.value.clone(),
})
}
}
/// Lookup information for key invalidation.
pub(crate) struct LookupInfo<K> {
/// Time of creation of a cache [`Entry`].
/// We use this during invalidation lookups to prevent eviction of a newer
/// entry sharing the same key (it might've been inserted by a different
/// task after we got the entry we're trying to invalidate now).
created_at: Instant,
/// Search by this key.
key: K,
}

View File

@@ -236,7 +236,7 @@ impl AuthInfo {
&self,
ctx: &RequestContext,
compute: &mut ComputeConnection,
user_info: ComputeUserInfo,
user_info: &ComputeUserInfo,
) -> Result<PostgresSettings, PostgresError> {
// client config with stubbed connect info.
// TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely,
@@ -272,7 +272,7 @@ impl AuthInfo {
secret_key,
},
compute.hostname.to_string(),
user_info,
user_info.clone(),
);
Ok(PostgresSettings {

View File

@@ -24,7 +24,6 @@ pub struct ProxyConfig {
pub authentication_config: AuthenticationConfig,
pub rest_config: RestConfig,
pub proxy_protocol_v2: ProxyProtocolV2,
pub region: String,
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,

View File

@@ -11,11 +11,12 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::ClientRequestError;
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::{ClientRequestError, ErrorSource, prepare_client_connection};
use crate::proxy::{ErrorSource, finish_client_init};
use crate::util::run_until_cancelled;
pub async fn task_main(
@@ -89,12 +90,7 @@ pub async fn task_main(
}
}
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let res = handle_client(
config,
@@ -231,13 +227,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.await?;
let pg_settings = auth_info
.authenticate(ctx, &mut node, user_info)
.authenticate(ctx, &mut node, &user_info)
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
.await?;
let session = cancellation_handler.get_key();
prepare_client_connection(&pg_settings, *session.key(), &mut stream);
finish_client_init(&pg_settings, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
let session_id = ctx.session_id();

View File

@@ -46,7 +46,6 @@ struct RequestContextInner {
pub(crate) session_id: Uuid,
pub(crate) protocol: Protocol,
first_packet: chrono::DateTime<Utc>,
region: &'static str,
pub(crate) span: Span,
// filled in as they are discovered
@@ -94,7 +93,6 @@ impl Clone for RequestContext {
session_id: inner.session_id,
protocol: inner.protocol,
first_packet: inner.first_packet,
region: inner.region,
span: info_span!("background_task"),
project: inner.project,
@@ -124,12 +122,7 @@ impl Clone for RequestContext {
}
impl RequestContext {
pub fn new(
session_id: Uuid,
conn_info: ConnectionInfo,
protocol: Protocol,
region: &'static str,
) -> Self {
pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self {
// TODO: be careful with long lived spans
let span = info_span!(
"connect_request",
@@ -145,7 +138,6 @@ impl RequestContext {
session_id,
protocol,
first_packet: Utc::now(),
region,
span,
project: None,
@@ -179,7 +171,7 @@ impl RequestContext {
let ip = IpAddr::from([127, 0, 0, 1]);
let addr = SocketAddr::new(ip, 5432);
let conn_info = ConnectionInfo { addr, extra: None };
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp)
}
pub(crate) fn console_application_name(&self) -> String {

View File

@@ -74,7 +74,7 @@ pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
#[derive(parquet_derive::ParquetRecordWriter)]
pub(crate) struct RequestData {
region: &'static str,
region: String,
protocol: &'static str,
/// Must be UTC. The derive macro doesn't like the timezones
timestamp: chrono::NaiveDateTime,
@@ -147,7 +147,7 @@ impl From<&RequestContextInner> for RequestData {
}),
jwt_issuer: value.jwt_issuer.clone(),
protocol: value.protocol.as_str(),
region: value.region,
region: String::new(),
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
success: value.success,
cold_start_info: value.cold_start_info.as_str(),
@@ -167,6 +167,7 @@ impl From<&RequestContextInner> for RequestData {
pub async fn worker(
cancellation_token: CancellationToken,
config: ParquetUploadArgs,
region: String,
) -> anyhow::Result<()> {
let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
tracing::warn!("parquet request upload: no s3 bucket configured");
@@ -232,12 +233,17 @@ pub async fn worker(
.context("remote storage for disconnect events init")?;
let parquet_config_disconnect = parquet_config.clone();
tokio::try_join!(
worker_inner(storage, rx, parquet_config),
worker_inner(storage_disconnect, rx_disconnect, parquet_config_disconnect)
worker_inner(storage, rx, parquet_config, &region),
worker_inner(
storage_disconnect,
rx_disconnect,
parquet_config_disconnect,
&region
)
)
.map(|_| ())
} else {
worker_inner(storage, rx, parquet_config).await
worker_inner(storage, rx, parquet_config, &region).await
}
}
@@ -257,6 +263,7 @@ async fn worker_inner(
storage: GenericRemoteStorage,
rx: impl Stream<Item = RequestData>,
config: ParquetConfig,
region: &str,
) -> anyhow::Result<()> {
#[cfg(any(test, feature = "testing"))]
let storage = if config.test_remote_failures > 0 {
@@ -277,7 +284,8 @@ async fn worker_inner(
let mut last_upload = time::Instant::now();
let mut len = 0;
while let Some(row) = rx.next().await {
while let Some(mut row) = rx.next().await {
region.clone_into(&mut row.region);
rows.push(row);
let force = last_upload.elapsed() > config.max_duration;
if rows.len() == config.rows_per_group || force {
@@ -533,7 +541,7 @@ mod tests {
auth_method: None,
jwt_issuer: None,
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
region: "us-east-1",
region: String::new(),
error: None,
success: rng.r#gen(),
cold_start_info: "no",
@@ -565,7 +573,9 @@ mod tests {
.await
.unwrap();
worker_inner(storage, rx, config).await.unwrap();
worker_inner(storage, rx, config, "us-east-1")
.await
.unwrap();
let mut files = WalkDir::new(tmpdir.as_std_path())
.into_iter()

View File

@@ -8,10 +8,10 @@ use crate::config::TlsConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::pglb::TlsRequired;
use crate::pqproto::{
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
};
use crate::proxy::TlsRequired;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::tls::PG_ALPN_PROTOCOL;

View File

@@ -2,3 +2,332 @@ pub mod copy_bidirectional;
pub mod handshake;
pub mod inprocess;
pub mod passthrough;
use std::sync::Arc;
use futures::FutureExt;
use smol_str::ToSmolStr;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};
use crate::auth;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
pub use crate::pglb::copy_bidirectional::ErrorSource;
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::handle_client;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::Stream;
use crate::util::run_until_cancelled;
pub const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
#[derive(Error, Debug)]
#[error("{ERR_INSECURE_CONNECTION}")]
pub struct TlsRequired;
impl ReportableError for TlsRequired {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
impl UserFacingError for TlsRequired {}
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, conn_info) = match config.proxy_protocol_v2 {
ProxyProtocolV2::Required => {
match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
return;
}
// our load balancers will not send any more data. let's just exit immediately
Ok((_socket, ConnectHeader::Local)) => {
debug!("healthcheck received");
return;
}
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
}
}
// ignore the header - it cannot be confused for a postgres or http connection so will
// error later.
ProxyProtocolV2::Rejected => (
socket,
ConnectionInfo {
addr: peer_addr,
extra: None,
},
),
};
match socket.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!(
"per-client task finished with an error: failed to set socket option: {e:#}"
);
return;
}
}
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let res = handle_connection(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
)
.instrument(ctx.span())
.boxed()
.await;
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
pub(crate) enum ClientMode {
Tcp,
Websockets { hostname: Option<String> },
}
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
pub fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
pub fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
// TLS is None here if using websockets, because the connection is already encrypted.
ClientMode::Websockets { .. } => None,
}
}
}
#[derive(Debug, Error)]
// almost all errors should be reported to the user, but there's a few cases where we cannot
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
// we cannot be sure the client even understands our error message
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
pub(crate) enum ClientRequestError {
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
#[error("{0}")]
Handshake(#[from] HandshakeError),
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
client: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
let (mut client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(client, params) => (client, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
cancellation_handler_clone
.cancel_session(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
return Ok(None);
}
};
drop(pause);
ctx.set_db_options(params.clone());
let common_names = tls.map(|tls| &tls.common_names);
let (node, cancel_on_shutdown) = handle_client(
config,
auth_backend,
ctx,
cancellation_handler,
&mut client,
&mode,
endpoint_rate_limiter,
common_names,
&params,
)
.await?;
let client = client.flush_and_into_inner().await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
Ok(Some(ProxyPassthrough {
client,
compute: node.stream,
aux: node.aux,
private_link_id,
_cancel_on_shutdown: cancel_on_shutdown,
_req: request_gauge,
_conn: conn_gauge,
_db_conn: node.guage,
}))
}

View File

@@ -112,7 +112,7 @@ where
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
// Do not need to retrieve a new node_info, just return the old one.
if should_retry(&err, num_retries, compute.retry) {
if !should_retry(&err, num_retries, compute.retry) {
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
outcome: ConnectOutcome::Failed,

View File

@@ -5,328 +5,64 @@ pub(crate) mod connect_compute;
pub(crate) mod retry;
pub(crate) mod wake_compute;
use std::collections::HashSet;
use std::convert::Infallible;
use std::sync::Arc;
use futures::FutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
use thiserror::Error;
use smol_str::{SmolStr, format_smolstr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};
use tokio::sync::oneshot;
use tracing::Instrument;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::cache::Cache;
use crate::cancellation::CancellationHandler;
use crate::compute::ComputeConnection;
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::control_plane::client::ControlPlaneClient;
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::pglb::{ClientMode, ClientRequestError};
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::retry::ShouldRetryWakeCompute;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::types::EndpointCacheKey;
use crate::util::run_until_cancelled;
use crate::{auth, compute};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
#[derive(Error, Debug)]
#[error("{ERR_INSECURE_CONNECTION}")]
pub struct TlsRequired;
impl ReportableError for TlsRequired {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
impl UserFacingError for TlsRequired {}
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, conn_info) = match config.proxy_protocol_v2 {
ProxyProtocolV2::Required => {
match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
return;
}
// our load balancers will not send any more data. let's just exit immediately
Ok((_socket, ConnectHeader::Local)) => {
debug!("healthcheck received");
return;
}
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
}
}
// ignore the header - it cannot be confused for a postgres or http connection so will
// error later.
ProxyProtocolV2::Rejected => (
socket,
ConnectionInfo {
addr: peer_addr,
extra: None,
},
),
};
match socket.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!(
"per-client task finished with an error: failed to set socket option: {e:#}"
);
return;
}
}
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let res = handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
)
.instrument(ctx.span())
.boxed()
.await;
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
pub(crate) enum ClientMode {
Tcp,
Websockets { hostname: Option<String> },
}
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub(crate) fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
// TLS is None here if using websockets, because the connection is already encrypted.
ClientMode::Websockets { .. } => None,
}
}
}
#[derive(Debug, Error)]
// almost all errors should be reported to the user, but there's a few cases where we cannot
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
// we cannot be sure the client even understands our error message
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
pub(crate) enum ClientRequestError {
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
#[error("{0}")]
Handshake(#[from] HandshakeError),
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
mode: ClientMode,
client: &mut PqStream<Stream<S>>,
mode: &ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
cancellation_handler_clone
.cancel_session(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
return Ok(None);
}
};
drop(pause);
ctx.set_db_options(params.clone());
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
common_names: Option<&HashSet<String>>,
params: &StartupMessageParams,
) -> Result<(ComputeConnection, oneshot::Sender<Infallible>), ClientRequestError> {
let hostname = mode.hostname(client.get_ref());
// Extract credentials which we're going to use for auth.
let result = auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, params, hostname, common_names))
.transpose();
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
};
let user = user_info.get_user().to_owned();
let user_info = match user_info
.authenticate(
ctx,
&mut stream,
client,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
@@ -339,7 +75,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return Err(stream
return Err(client
.throw_error(e, Some(ctx))
.instrument(params_span)
.await)?;
@@ -352,37 +88,67 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
};
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
auth_info.set_startup_params(&params, params_compat);
auth_info.set_startup_params(params, params_compat);
let res = connect_to_compute(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
},
&auth::Backend::ControlPlane(cplane, creds.info.clone()),
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.await;
let mut node = match res {
Ok(node) => node,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
let mut node;
let mut attempt = 0;
let connect = TcpMechanism {
locks: &config.connect_compute_locks,
};
let backend = auth::Backend::ControlPlane(cplane, creds.info);
let pg_settings = auth_info.authenticate(ctx, &mut node, creds.info).await;
let pg_settings = match pg_settings {
Ok(pg_settings) => pg_settings,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
// NOTE: This is messy, but should hopefully be detangled with PGLB.
// We wanted to separate the concerns of **connect** to compute (a PGLB operation),
// from **authenticate** to compute (a NeonKeeper operation).
//
// This unfortunately removed retry handling for one error case where
// the compute was cached, and we connected, but the compute cache was actually stale
// and is associated with the wrong endpoint. We detect this when the **authentication** fails.
// As such, we retry once here if the `authenticate` function fails and the error is valid to retry.
let pg_settings = loop {
attempt += 1;
// TODO: callback to pglb
let res = connect_to_compute(
ctx,
&connect,
&backend,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.await;
match res {
Ok(n) => node = n,
Err(e) => return Err(client.throw_error(e, Some(ctx)).await)?,
}
let auth::Backend::ControlPlane(cplane, user_info) = &backend else {
unreachable!("ensured above");
};
let res = auth_info.authenticate(ctx, &mut node, user_info).await;
match res {
Ok(pg_settings) => break pg_settings,
Err(e) if attempt < 2 && e.should_retry_wake_compute() => {
tracing::warn!(error = ?e, "retrying wake compute");
#[allow(irrefutable_let_patterns)]
if let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane {
let key = user_info.endpoint_cache_key();
cplane_proxy_v1.caches.node_info.invalidate(&key);
}
}
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
}
};
let session = cancellation_handler.get_key();
prepare_client_connection(&pg_settings, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
finish_client_init(&pg_settings, *session.key(), client);
let session_id = ctx.session_id();
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
let (cancel_on_shutdown, cancel) = oneshot::channel();
tokio::spawn(async move {
session
.maintain_cancel_key(
@@ -394,50 +160,32 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.await;
});
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
Ok(Some(ProxyPassthrough {
client: stream,
compute: node.stream,
aux: node.aux,
private_link_id,
_cancel_on_shutdown: cancel_on_shutdown,
_req: request_gauge,
_conn: conn_gauge,
_db_conn: node.guage,
}))
Ok((node, cancel_on_shutdown))
}
/// Finish client connection initialization: confirm auth success, send params, etc.
pub(crate) fn prepare_client_connection(
pub(crate) fn finish_client_init(
settings: &compute::PostgresSettings,
cancel_key_data: CancelKeyData,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) {
// Forward all deferred notices to the client.
for notice in &settings.delayed_notice {
stream.write_raw(notice.as_bytes().len(), b'N', |buf| {
client.write_raw(notice.as_bytes().len(), b'N', |buf| {
buf.extend_from_slice(notice.as_bytes());
});
}
// Forward all postgres connection params to the client.
for (name, value) in &settings.params {
stream.write_message(BeMessage::ParameterStatus {
client.write_message(BeMessage::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
});
}
stream.write_message(BeMessage::BackendKeyData(cancel_key_data));
stream.write_message(BeMessage::ReadyForQuery);
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
client.write_message(BeMessage::ReadyForQuery);
}
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
@@ -447,7 +195,7 @@ impl NeonOptions {
// proxy options:
/// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute.
const PARAMS_COMPAT: &str = "proxy_params_compat";
pub const PARAMS_COMPAT: &str = "proxy_params_compat";
// cplane options:

View File

@@ -3,7 +3,7 @@ use std::io;
use tokio::time;
use crate::compute;
use crate::compute::{self, PostgresError};
use crate::config::RetryConfig;
pub(crate) trait CouldRetry {
@@ -115,6 +115,14 @@ impl ShouldRetryWakeCompute for compute::ConnectionError {
}
}
impl ShouldRetryWakeCompute for PostgresError {
fn should_retry_wake_compute(&self) -> bool {
match self {
PostgresError::Postgres(error) => error.should_retry_wake_compute(),
}
}
}
pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Duration {
config
.base_delay

View File

@@ -14,6 +14,9 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
use tokio_util::codec::{Decoder, Encoder};
use super::*;
use crate::config::TlsConfig;
use crate::context::RequestContext;
use crate::pglb::handshake::{HandshakeData, handshake};
enum Intercept {
None,

View File

@@ -3,6 +3,7 @@
mod mitm;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, bail};
@@ -10,26 +11,31 @@ use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::SslMode;
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{ShouldRetryWakeCompute, retry_after};
use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
use tokio::io::DuplexStream;
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
use tracing_test::traced_test;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::config::{ComputeConfig, RetryConfig};
use crate::config::{ComputeConfig, RetryConfig, TlsConfig};
use crate::context::RequestContext;
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::error::{ErrorKind, ReportableError};
use crate::pglb::ERR_INSECURE_CONNECTION;
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute};
use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after};
use crate::stream::{PqStream, Stream};
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{sasl, scram};
use crate::{auth, compute, sasl, scram};
/// Generate a set of TLS certificates: CA + server.
fn generate_certs(
@@ -374,6 +380,7 @@ fn connect_compute_total_wait() {
#[derive(Clone, Copy, Debug)]
enum ConnectAction {
Wake,
WakeCold,
WakeFail,
WakeRetry,
Connect,
@@ -504,6 +511,9 @@ impl TestControlPlaneClient for TestConnectMechanism {
*counter += 1;
match action {
ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
ConnectAction::WakeCold => Ok(CachedNodeInfo::new_uncached(
helper_create_uncached_node_info(),
)),
ConnectAction::WakeFail => {
let err = control_plane::errors::ControlPlaneError::Message(Box::new(
ControlPlaneErrorMessage {
@@ -551,8 +561,8 @@ impl TestControlPlaneClient for TestConnectMechanism {
}
}
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = NodeInfo {
fn helper_create_uncached_node_info() -> NodeInfo {
NodeInfo {
conn_info: compute::ConnectInfo {
host: "test".into(),
port: 5432,
@@ -566,7 +576,11 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
compute_id: "compute".into(),
cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
},
};
}
}
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = helper_create_uncached_node_info();
let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone()));
node2.map(|()| node)
}
@@ -742,7 +756,7 @@ async fn fail_no_wake_skips_cache_invalidation() {
let ctx = RequestContext::test();
let mech = TestConnectMechanism::new(vec![
ConnectAction::Wake,
ConnectAction::FailNoWake,
ConnectAction::RetryNoWake,
ConnectAction::Connect,
]);
let user = helper_create_connect_info(&mech);
@@ -788,7 +802,7 @@ async fn retry_no_wake_skips_invalidation() {
let ctx = RequestContext::test();
// Wake → RetryNoWake (retryable + NOT wakeable)
let mechanism = TestConnectMechanism::new(vec![Wake, RetryNoWake]);
let mechanism = TestConnectMechanism::new(vec![Wake, RetryNoWake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
@@ -802,3 +816,44 @@ async fn retry_no_wake_skips_invalidation() {
"invalidating stalled compute node info cache entry"
));
}
#[tokio::test]
#[traced_test]
async fn retry_no_wake_error_fast() {
let _ = env_logger::try_init();
use ConnectAction::*;
let ctx = RequestContext::test();
// Wake → FailNoWake (not retryable + NOT wakeable)
let mechanism = TestConnectMechanism::new(vec![Wake, FailNoWake]);
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap_err();
mechanism.verify();
// Because FailNoWake has wakeable=false, we must NOT see invalidate_cache
assert!(!logs_contain(
"invalidating stalled compute node info cache entry"
));
}
#[tokio::test]
#[traced_test]
async fn retry_cold_wake_skips_invalidation() {
let _ = env_logger::try_init();
use ConnectAction::*;
let ctx = RequestContext::test();
// WakeCold → FailNoWake (not retryable + NOT wakeable)
let mechanism = TestConnectMechanism::new(vec![WakeCold, Retry, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap();
mechanism.verify();
}

View File

@@ -139,12 +139,6 @@ impl RateBucketInfo {
Self::new(200, Duration::from_secs(600)),
];
// For all the sessions will be cancel key. So this limit is essentially global proxy limit.
pub const DEFAULT_REDIS_SET: [Self; 2] = [
Self::new(100_000, Duration::from_secs(1)),
Self::new(50_000, Duration::from_secs(10)),
];
pub fn rps(&self) -> f64 {
(self.max_rpi as f64) / self.interval.as_secs_f64()
}

View File

@@ -5,11 +5,9 @@ use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
pub struct RedisKVClient {
client: ConnectionWithCredentialsProvider,
limiter: GlobalRateLimiter,
}
#[allow(async_fn_in_trait)]
@@ -30,11 +28,8 @@ impl Queryable for Cmd {
}
impl RedisKVClient {
pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self {
Self {
client,
limiter: GlobalRateLimiter::new(info.into()),
}
pub fn new(client: ConnectionWithCredentialsProvider) -> Self {
Self { client }
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
@@ -49,11 +44,6 @@ impl RedisKVClient {
&mut self,
q: &impl Queryable,
) -> anyhow::Result<T> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping query");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
let e = match q.query(&mut self.client).await {
Ok(t) => return Ok(t),
Err(e) => e,

View File

@@ -141,29 +141,19 @@ where
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
cache: Arc<C>,
region_id: String,
}
impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
region_id: self.region_id.clone(),
}
}
}
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
Self { cache, region_id }
}
pub(crate) async fn increment_active_listeners(&self) {
self.cache.increment_active_listeners().await;
}
pub(crate) async fn decrement_active_listeners(&self) {
self.cache.decrement_active_listeners().await;
pub(crate) fn new(cache: Arc<C>) -> Self {
Self { cache }
}
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
@@ -276,7 +266,7 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
let mut conn = match try_connect(&redis).await {
Ok(conn) => {
handler.increment_active_listeners().await;
handler.cache.increment_active_listeners().await;
conn
}
Err(e) => {
@@ -297,11 +287,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
}
if cancellation_token.is_cancelled() {
handler.decrement_active_listeners().await;
handler.cache.decrement_active_listeners().await;
return Ok(());
}
}
handler.decrement_active_listeners().await;
handler.cache.decrement_active_listeners().await;
}
}
@@ -310,12 +300,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
pub async fn task_main<C>(
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
region_id: String,
) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
let handler = MessageHandler::new(cache, region_id);
let handler = MessageHandler::new(cache);
// 6h - 1m.
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));

View File

@@ -418,12 +418,7 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Ws,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);
ctx.set_user_agent(
request
@@ -463,12 +458,7 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Http,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
let span = ctx.span();
let testodrome_id = request

View File

@@ -17,7 +17,8 @@ use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::{ClientMode, ErrorSource, handle_client};
use crate::pglb::{ClientMode, handle_connection};
use crate::proxy::ErrorSource;
use crate::rate_limiter::EndpointRateLimiter;
pin_project! {
@@ -142,7 +143,7 @@ pub(crate) async fn serve_websocket(
.client_connections
.guard(crate::metrics::Protocol::Ws);
let res = Box::pin(handle_client(
let res = Box::pin(handle_connection(
config,
auth_backend,
&ctx,

View File

@@ -1,5 +1,5 @@
[toolchain]
channel = "1.87.0"
channel = "1.88.0"
profile = "default"
# The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy.
# https://rust-lang.github.io/rustup/concepts/profiles.html

View File

@@ -220,7 +220,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
stripe_size: ShardStripeSize(stripe_size),
};
self.shard =
Some(ShardIdentity::from_params(ShardNumber(number), &params));
Some(ShardIdentity::from_params(ShardNumber(number), params));
}
_ => {
return Err(QueryError::Other(anyhow::anyhow!(

View File

@@ -1,5 +1,6 @@
use std::cmp::min;
use std::io::{self, ErrorKind};
use std::ops::RangeInclusive;
use std::sync::Arc;
use anyhow::{Context, Result, anyhow, bail};
@@ -34,7 +35,7 @@ use crate::control_file::CONTROL_FILE_NAME;
use crate::state::{EvictionState, TimelinePersistentState};
use crate::timeline::{Timeline, TimelineError, WalResidentTimeline};
use crate::timelines_global_map::{create_temp_timeline_dir, validate_temp_timeline};
use crate::wal_storage::open_wal_file;
use crate::wal_storage::{open_wal_file, wal_file_paths};
use crate::{GlobalTimelines, debug_dump, wal_backup};
/// Stream tar archive of timeline to tx.
@@ -95,8 +96,8 @@ pub async fn stream_snapshot(
/// State needed while streaming the snapshot.
pub struct SnapshotContext {
pub from_segno: XLogSegNo, // including
pub upto_segno: XLogSegNo, // including
/// The interval of segment numbers. If None, the timeline hasn't had writes yet, so only send the control file
pub from_to_segno: Option<RangeInclusive<XLogSegNo>>,
pub term: Term,
pub last_log_term: Term,
pub flush_lsn: Lsn,
@@ -174,23 +175,35 @@ pub async fn stream_snapshot_resident_guts(
.await?;
pausable_failpoint!("sk-snapshot-after-list-pausable");
let tli_dir = tli.get_timeline_dir();
info!(
"sending {} segments [{:#X}-{:#X}], term={}, last_log_term={}, flush_lsn={}",
bctx.upto_segno - bctx.from_segno + 1,
bctx.from_segno,
bctx.upto_segno,
bctx.term,
bctx.last_log_term,
bctx.flush_lsn,
);
for segno in bctx.from_segno..=bctx.upto_segno {
let (mut sf, is_partial) = open_wal_file(&tli_dir, segno, bctx.wal_seg_size).await?;
let mut wal_file_name = XLogFileName(PG_TLI, segno, bctx.wal_seg_size);
if is_partial {
wal_file_name.push_str(".partial");
if let Some(from_to_segno) = &bctx.from_to_segno {
let tli_dir = tli.get_timeline_dir();
info!(
"sending {} segments [{:#X}-{:#X}], term={}, last_log_term={}, flush_lsn={}",
from_to_segno.end() - from_to_segno.start() + 1,
from_to_segno.start(),
from_to_segno.end(),
bctx.term,
bctx.last_log_term,
bctx.flush_lsn,
);
for segno in from_to_segno.clone() {
let Some((mut sf, is_partial)) =
open_wal_file(&tli_dir, segno, bctx.wal_seg_size).await?
else {
// File is not found
let (wal_file_path, _wal_file_partial_path) =
wal_file_paths(&tli_dir, segno, bctx.wal_seg_size);
tracing::warn!("couldn't find WAL segment file {wal_file_path}");
bail!("couldn't find WAL segment file {wal_file_path}")
};
let mut wal_file_name = XLogFileName(PG_TLI, segno, bctx.wal_seg_size);
if is_partial {
wal_file_name.push_str(".partial");
}
ar.append_file(&wal_file_name, &mut sf).await?;
}
ar.append_file(&wal_file_name, &mut sf).await?;
} else {
info!("Not including any segments into the snapshot");
}
// Do the term check before ar.finish to make archive corrupted in case of
@@ -338,19 +351,26 @@ impl WalResidentTimeline {
// removed further than `backup_lsn`. Since we're holding shared_state
// lock and setting `wal_removal_on_hold` later, it guarantees that WAL
// won't be removed until we're done.
let timeline_state = shared_state.sk.state();
let from_lsn = min(
shared_state.sk.state().remote_consistent_lsn,
shared_state.sk.state().backup_lsn,
timeline_state.remote_consistent_lsn,
timeline_state.backup_lsn,
);
let flush_lsn = shared_state.sk.flush_lsn();
let (send_segments, msg) = if from_lsn == Lsn::INVALID {
(false, "snapshot is called on uninitialized timeline")
} else {
(true, "timeline is initialized")
};
tracing::info!(
remote_consistent_lsn=%timeline_state.remote_consistent_lsn,
backup_lsn=%timeline_state.backup_lsn,
%flush_lsn,
"{msg}"
);
if from_lsn == Lsn::INVALID {
// this is possible if snapshot is called before handling first
// elected message
bail!("snapshot is called on uninitialized timeline");
}
let from_segno = from_lsn.segment_number(wal_seg_size);
let term = shared_state.sk.state().acceptor_state.term;
let last_log_term = shared_state.sk.last_log_term();
let flush_lsn = shared_state.sk.flush_lsn();
let upto_segno = flush_lsn.segment_number(wal_seg_size);
// have some limit on max number of segments as a sanity check
const MAX_ALLOWED_SEGS: u64 = 1000;
@@ -376,9 +396,9 @@ impl WalResidentTimeline {
drop(shared_state);
let tli_copy = self.wal_residence_guard().await?;
let from_to_segno = send_segments.then_some(from_segno..=upto_segno);
let bctx = SnapshotContext {
from_segno,
upto_segno,
from_to_segno,
term,
last_log_term,
flush_lsn,

View File

@@ -9,7 +9,7 @@ use anyhow::{Result, bail};
use postgres_ffi::WAL_SEGMENT_SIZE;
use postgres_versioninfo::{PgMajorVersion, PgVersionId};
use safekeeper_api::membership::Configuration;
use safekeeper_api::models::{TimelineMembershipSwitchResponse, TimelineTermBumpResponse};
use safekeeper_api::models::TimelineTermBumpResponse;
use safekeeper_api::{INITIAL_TERM, ServerInfo, Term};
use serde::{Deserialize, Serialize};
use tracing::info;
@@ -83,6 +83,11 @@ pub enum EvictionState {
Offloaded(Lsn),
}
pub struct MembershipSwitchResult {
pub previous_conf: Configuration,
pub current_conf: Configuration,
}
impl TimelinePersistentState {
/// commit_lsn is the same as start_lsn in the normal creaiton; see
/// `TimelineCreateRequest` comments.`
@@ -261,10 +266,7 @@ where
/// Switch into membership configuration `to` if it is higher than the
/// current one.
pub async fn membership_switch(
&mut self,
to: Configuration,
) -> Result<TimelineMembershipSwitchResponse> {
pub async fn membership_switch(&mut self, to: Configuration) -> Result<MembershipSwitchResult> {
let before = self.mconf.clone();
// Is switch allowed?
if to.generation <= self.mconf.generation {
@@ -278,7 +280,7 @@ where
self.finish_change(&state).await?;
info!("switched membership conf to {} from {}", to, before);
}
Ok(TimelineMembershipSwitchResponse {
Ok(MembershipSwitchResult {
previous_conf: before,
current_conf: self.mconf.clone(),
})

View File

@@ -190,7 +190,14 @@ impl StateSK {
&mut self,
to: Configuration,
) -> Result<TimelineMembershipSwitchResponse> {
self.state_mut().membership_switch(to).await
let result = self.state_mut().membership_switch(to).await?;
Ok(TimelineMembershipSwitchResponse {
previous_conf: result.previous_conf,
current_conf: result.current_conf,
term: self.state().acceptor_state.term,
flush_lsn: self.flush_lsn(),
})
}
/// Close open WAL files to release FDs.

View File

@@ -9,7 +9,7 @@
use std::cmp::{max, min};
use std::future::Future;
use std::io::{self, SeekFrom};
use std::io::{ErrorKind, SeekFrom};
use std::pin::Pin;
use anyhow::{Context, Result, bail};
@@ -154,8 +154,8 @@ pub struct PhysicalStorage {
/// record
///
/// Partial segment 002 has no WAL records, and it will be removed by the
/// next truncate_wal(). This flag will be set to true after the first
/// truncate_wal() call.
/// next truncate_wal(). This flag will be set to false after the first
/// successful truncate_wal() call.
///
/// [`write_lsn`]: Self::write_lsn
pending_wal_truncation: bool,
@@ -202,6 +202,8 @@ impl PhysicalStorage {
ttid.timeline_id, flush_lsn, state.commit_lsn, state.peer_horizon_lsn,
);
if flush_lsn < state.commit_lsn {
// note: can never happen. find_end_of_wal returns provided start_lsn
// (state.commit_lsn in our case) if it doesn't find anything.
bail!(
"timeline {} potential data loss: flush_lsn {} by find_end_of_wal is less than commit_lsn {} from control file",
ttid.timeline_id,
@@ -794,26 +796,13 @@ impl WalReader {
// Try to open local file, if we may have WAL locally
if self.pos >= self.local_start_lsn {
let res = open_wal_file(&self.timeline_dir, segno, self.wal_seg_size).await;
match res {
Ok((mut file, _)) => {
file.seek(SeekFrom::Start(xlogoff as u64)).await?;
return Ok(Box::pin(file));
}
Err(e) => {
let is_not_found = e.chain().any(|e| {
if let Some(e) = e.downcast_ref::<io::Error>() {
e.kind() == io::ErrorKind::NotFound
} else {
false
}
});
if !is_not_found {
return Err(e);
}
// NotFound is expected, fall through to remote read
}
};
let res = open_wal_file(&self.timeline_dir, segno, self.wal_seg_size).await?;
if let Some((mut file, _)) = res {
file.seek(SeekFrom::Start(xlogoff as u64)).await?;
return Ok(Box::pin(file));
} else {
// NotFound is expected, fall through to remote read
}
}
// Try to open remote file, if remote reads are enabled
@@ -832,26 +821,31 @@ pub(crate) async fn open_wal_file(
timeline_dir: &Utf8Path,
segno: XLogSegNo,
wal_seg_size: usize,
) -> Result<(tokio::fs::File, bool)> {
) -> Result<Option<(tokio::fs::File, bool)>> {
let (wal_file_path, wal_file_partial_path) = wal_file_paths(timeline_dir, segno, wal_seg_size);
// First try to open the .partial file.
let mut partial_path = wal_file_path.to_owned();
partial_path.set_extension("partial");
if let Ok(opened_file) = tokio::fs::File::open(&wal_file_partial_path).await {
return Ok((opened_file, true));
return Ok(Some((opened_file, true)));
}
// If that failed, try it without the .partial extension.
let pf = tokio::fs::File::open(&wal_file_path)
.await
let pf_res = tokio::fs::File::open(&wal_file_path).await;
if let Err(e) = &pf_res {
if e.kind() == ErrorKind::NotFound {
return Ok(None);
}
}
let pf = pf_res
.with_context(|| format!("failed to open WAL file {wal_file_path:#}"))
.map_err(|e| {
warn!("{}", e);
warn!("{e}");
e
})?;
Ok((pf, false))
Ok(Some((pf, false)))
}
/// Helper returning full path to WAL segment file and its .partial brother.

View File

@@ -20,6 +20,7 @@ camino.workspace = true
chrono.workspace = true
clap.workspace = true
clashmap.workspace = true
compute_api.workspace = true
cron.workspace = true
fail.workspace = true
futures.workspace = true

View File

@@ -5,7 +5,8 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use control_plane::endpoint::{ComputeControlPlane, EndpointStatus, PageserverProtocol};
use compute_api::spec::PageserverProtocol;
use control_plane::endpoint::{ComputeControlPlane, EndpointStatus};
use control_plane::local_env::LocalEnv;
use futures::StreamExt;
use hyper::StatusCode;
@@ -13,11 +14,12 @@ use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT;
use pageserver_api::controller_api::AvailabilityZone;
use pageserver_api::shard::{ShardCount, ShardNumber, ShardStripeSize, TenantShardId};
use postgres_connection::parse_host_port;
use safekeeper_api::membership::SafekeeperGeneration;
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, info_span};
use utils::backoff::{self};
use utils::id::{NodeId, TenantId};
use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId};
use crate::service::Config;
@@ -35,7 +37,7 @@ struct UnshardedComputeHookTenant {
preferred_az: Option<AvailabilityZone>,
// Must hold this lock to send a notification.
send_lock: Arc<tokio::sync::Mutex<Option<ComputeRemoteState>>>,
send_lock: Arc<tokio::sync::Mutex<Option<ComputeRemoteTenantState>>>,
}
struct ShardedComputeHookTenant {
stripe_size: ShardStripeSize,
@@ -48,7 +50,7 @@ struct ShardedComputeHookTenant {
// Must hold this lock to send a notification. The contents represent
// the last successfully sent notification, and are used to coalesce multiple
// updates by only sending when there is a chance since our last successful send.
send_lock: Arc<tokio::sync::Mutex<Option<ComputeRemoteState>>>,
send_lock: Arc<tokio::sync::Mutex<Option<ComputeRemoteTenantState>>>,
}
/// Represents our knowledge of the compute's state: we can update this when we get a
@@ -56,9 +58,9 @@ struct ShardedComputeHookTenant {
///
/// Should be wrapped in an Option<>, as we cannot always know the remote state.
#[derive(PartialEq, Eq, Debug)]
struct ComputeRemoteState {
struct ComputeRemoteState<R> {
// The request body which was acked by the compute
request: ComputeHookNotifyRequest,
request: R,
// Whether the cplane indicated that the state was applied to running computes, or just
// persisted. In the Neon control plane, this is the difference between a 423 response (meaning
@@ -66,6 +68,36 @@ struct ComputeRemoteState {
applied: bool,
}
type ComputeRemoteTenantState = ComputeRemoteState<NotifyAttachRequest>;
type ComputeRemoteTimelineState = ComputeRemoteState<NotifySafekeepersRequest>;
/// The trait which define the handler-specific types and methods.
/// We have two implementations of this trait so far:
/// - [`ComputeHookTenant`] for tenant attach notifications ("/notify-attach")
/// - [`ComputeHookTimeline`] for safekeeper change notifications ("/notify-safekeepers")
trait ApiMethod {
/// Type of the key which identifies the resource.
/// It's either TenantId for tenant attach notifications,
/// or TenantTimelineId for safekeeper change notifications.
type Key: std::cmp::Eq + std::hash::Hash + Clone;
type Request: serde::Serialize + std::fmt::Debug;
const API_PATH: &'static str;
fn maybe_send(
&self,
key: Self::Key,
lock: Option<tokio::sync::OwnedMutexGuard<Option<ComputeRemoteState<Self::Request>>>>,
) -> MaybeSendResult<Self::Request, Self::Key>;
async fn notify_local(
env: &LocalEnv,
cplane: &ComputeControlPlane,
req: &Self::Request,
) -> Result<(), NotifyError>;
}
enum ComputeHookTenant {
Unsharded(UnshardedComputeHookTenant),
Sharded(ShardedComputeHookTenant),
@@ -96,7 +128,7 @@ impl ComputeHookTenant {
}
}
fn get_send_lock(&self) -> &Arc<tokio::sync::Mutex<Option<ComputeRemoteState>>> {
fn get_send_lock(&self) -> &Arc<tokio::sync::Mutex<Option<ComputeRemoteTenantState>>> {
match self {
Self::Unsharded(unsharded_tenant) => &unsharded_tenant.send_lock,
Self::Sharded(sharded_tenant) => &sharded_tenant.send_lock,
@@ -190,19 +222,136 @@ impl ComputeHookTenant {
}
}
/// The state of a timeline we need to notify the compute about.
struct ComputeHookTimeline {
generation: SafekeeperGeneration,
safekeepers: Vec<SafekeeperInfo>,
send_lock: Arc<tokio::sync::Mutex<Option<ComputeRemoteTimelineState>>>,
}
impl ComputeHookTimeline {
/// Construct a new ComputeHookTimeline with the given safekeepers and generation.
fn new(generation: SafekeeperGeneration, safekeepers: Vec<SafekeeperInfo>) -> Self {
Self {
generation,
safekeepers,
send_lock: Arc::default(),
}
}
/// Update the state with a new SafekeepersUpdate.
/// Noop if the update generation is not greater than the current generation.
fn update(&mut self, sk_update: SafekeepersUpdate) {
if sk_update.generation > self.generation {
self.generation = sk_update.generation;
self.safekeepers = sk_update.safekeepers;
}
}
}
impl ApiMethod for ComputeHookTimeline {
type Key = TenantTimelineId;
type Request = NotifySafekeepersRequest;
const API_PATH: &'static str = "notify-safekeepers";
fn maybe_send(
&self,
ttid: TenantTimelineId,
lock: Option<tokio::sync::OwnedMutexGuard<Option<ComputeRemoteTimelineState>>>,
) -> MaybeSendNotifySafekeepersResult {
let locked = match lock {
Some(already_locked) => already_locked,
None => {
// Lock order: this _must_ be only a try_lock, because we are called inside of the [`ComputeHook::timelines`] lock.
let Ok(locked) = self.send_lock.clone().try_lock_owned() else {
return MaybeSendResult::AwaitLock((ttid, self.send_lock.clone()));
};
locked
}
};
if locked
.as_ref()
.is_some_and(|s| s.request.generation >= self.generation)
{
return MaybeSendResult::Noop;
}
MaybeSendResult::Transmit((
NotifySafekeepersRequest {
tenant_id: ttid.tenant_id,
timeline_id: ttid.timeline_id,
generation: self.generation,
safekeepers: self.safekeepers.clone(),
},
locked,
))
}
async fn notify_local(
_env: &LocalEnv,
cplane: &ComputeControlPlane,
req: &NotifySafekeepersRequest,
) -> Result<(), NotifyError> {
let NotifySafekeepersRequest {
tenant_id,
timeline_id,
generation,
safekeepers,
} = req;
for (endpoint_name, endpoint) in &cplane.endpoints {
if endpoint.tenant_id == *tenant_id
&& endpoint.timeline_id == *timeline_id
&& endpoint.status() == EndpointStatus::Running
{
tracing::info!("Reconfiguring safekeepers for endpoint {endpoint_name}");
let safekeepers = safekeepers.iter().map(|sk| sk.id).collect::<Vec<_>>();
endpoint
.reconfigure_safekeepers(safekeepers, *generation)
.await
.map_err(NotifyError::NeonLocal)?;
}
}
Ok(())
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ComputeHookNotifyRequestShard {
struct NotifyAttachRequestShard {
node_id: NodeId,
shard_number: ShardNumber,
}
/// Request body that we send to the control plane to notify it of where a tenant is attached
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ComputeHookNotifyRequest {
struct NotifyAttachRequest {
tenant_id: TenantId,
preferred_az: Option<String>,
stripe_size: Option<ShardStripeSize>,
shards: Vec<ComputeHookNotifyRequestShard>,
shards: Vec<NotifyAttachRequestShard>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub(crate) struct SafekeeperInfo {
pub id: NodeId,
/// Hostname of the safekeeper.
/// It exists for better debuggability. Might be missing.
/// Should not be used for anything else.
pub hostname: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct NotifySafekeepersRequest {
tenant_id: TenantId,
timeline_id: TimelineId,
generation: SafekeeperGeneration,
safekeepers: Vec<SafekeeperInfo>,
}
/// Error type for attempts to call into the control plane compute notification hook
@@ -234,42 +383,50 @@ pub(crate) enum NotifyError {
NeonLocal(anyhow::Error),
}
enum MaybeSendResult {
enum MaybeSendResult<R, K> {
// Please send this request while holding the lock, and if you succeed then write
// the request into the lock.
Transmit(
(
ComputeHookNotifyRequest,
tokio::sync::OwnedMutexGuard<Option<ComputeRemoteState>>,
R,
tokio::sync::OwnedMutexGuard<Option<ComputeRemoteState<R>>>,
),
),
// Something requires sending, but you must wait for a current sender then call again
AwaitLock(Arc<tokio::sync::Mutex<Option<ComputeRemoteState>>>),
AwaitLock((K, Arc<tokio::sync::Mutex<Option<ComputeRemoteState<R>>>>)),
// Nothing requires sending
Noop,
}
impl ComputeHookTenant {
type MaybeSendNotifyAttachResult = MaybeSendResult<NotifyAttachRequest, TenantId>;
type MaybeSendNotifySafekeepersResult = MaybeSendResult<NotifySafekeepersRequest, TenantTimelineId>;
impl ApiMethod for ComputeHookTenant {
type Key = TenantId;
type Request = NotifyAttachRequest;
const API_PATH: &'static str = "notify-attach";
fn maybe_send(
&self,
tenant_id: TenantId,
lock: Option<tokio::sync::OwnedMutexGuard<Option<ComputeRemoteState>>>,
) -> MaybeSendResult {
lock: Option<tokio::sync::OwnedMutexGuard<Option<ComputeRemoteTenantState>>>,
) -> MaybeSendNotifyAttachResult {
let locked = match lock {
Some(already_locked) => already_locked,
None => {
// Lock order: this _must_ be only a try_lock, because we are called inside of the [`ComputeHook::state`] lock.
// Lock order: this _must_ be only a try_lock, because we are called inside of the [`ComputeHook::tenants`] lock.
let Ok(locked) = self.get_send_lock().clone().try_lock_owned() else {
return MaybeSendResult::AwaitLock(self.get_send_lock().clone());
return MaybeSendResult::AwaitLock((tenant_id, self.get_send_lock().clone()));
};
locked
}
};
let request = match self {
Self::Unsharded(unsharded_tenant) => Some(ComputeHookNotifyRequest {
Self::Unsharded(unsharded_tenant) => Some(NotifyAttachRequest {
tenant_id,
shards: vec![ComputeHookNotifyRequestShard {
shards: vec![NotifyAttachRequestShard {
shard_number: ShardNumber(0),
node_id: unsharded_tenant.node_id,
}],
@@ -282,12 +439,12 @@ impl ComputeHookTenant {
Self::Sharded(sharded_tenant)
if sharded_tenant.shards.len() == sharded_tenant.shard_count.count() as usize =>
{
Some(ComputeHookNotifyRequest {
Some(NotifyAttachRequest {
tenant_id,
shards: sharded_tenant
.shards
.iter()
.map(|(shard_number, node_id)| ComputeHookNotifyRequestShard {
.map(|(shard_number, node_id)| NotifyAttachRequestShard {
shard_number: *shard_number,
node_id: *node_id,
})
@@ -332,98 +489,22 @@ impl ComputeHookTenant {
}
}
}
}
/// The compute hook is a destination for notifications about changes to tenant:pageserver
/// mapping. It aggregates updates for the shards in a tenant, and when appropriate reconfigures
/// the compute connection string.
pub(super) struct ComputeHook {
config: Config,
state: std::sync::Mutex<HashMap<TenantId, ComputeHookTenant>>,
authorization_header: Option<String>,
// Concurrency limiter, so that we do not overload the cloud control plane when updating
// large numbers of tenants (e.g. when failing over after a node failure)
api_concurrency: tokio::sync::Semaphore,
// This lock is only used in testing enviroments, to serialize calls into neon_lock
neon_local_lock: tokio::sync::Mutex<()>,
// We share a client across all notifications to enable connection re-use etc when
// sending large numbers of notifications
client: reqwest::Client,
}
/// Callers may give us a list of these when asking us to send a bulk batch
/// of notifications in the background. This is a 'notification' in the sense of
/// other code notifying us of a shard's status, rather than being the final notification
/// that we send upwards to the control plane for the whole tenant.
pub(crate) struct ShardUpdate<'a> {
pub(crate) tenant_shard_id: TenantShardId,
pub(crate) node_id: NodeId,
pub(crate) stripe_size: ShardStripeSize,
pub(crate) preferred_az: Option<Cow<'a, AvailabilityZone>>,
}
impl ComputeHook {
pub(super) fn new(config: Config) -> anyhow::Result<Self> {
let authorization_header = config
.control_plane_jwt_token
.clone()
.map(|jwt| format!("Bearer {jwt}"));
let mut client = reqwest::ClientBuilder::new().timeout(NOTIFY_REQUEST_TIMEOUT);
for cert in &config.ssl_ca_certs {
client = client.add_root_certificate(cert.clone());
}
let client = client
.build()
.context("Failed to build http client for compute hook")?;
Ok(Self {
state: Default::default(),
config,
authorization_header,
neon_local_lock: Default::default(),
api_concurrency: tokio::sync::Semaphore::new(API_CONCURRENCY),
client,
})
}
/// For test environments: use neon_local's LocalEnv to update compute
async fn do_notify_local(
&self,
reconfigure_request: &ComputeHookNotifyRequest,
async fn notify_local(
env: &LocalEnv,
cplane: &ComputeControlPlane,
req: &NotifyAttachRequest,
) -> Result<(), NotifyError> {
// neon_local updates are not safe to call concurrently, use a lock to serialize
// all calls to this function
let _locked = self.neon_local_lock.lock().await;
let Some(repo_dir) = self.config.neon_local_repo_dir.as_deref() else {
tracing::warn!(
"neon_local_repo_dir not set, likely a bug in neon_local; skipping compute update"
);
return Ok(());
};
let env = match LocalEnv::load_config(repo_dir) {
Ok(e) => e,
Err(e) => {
tracing::warn!("Couldn't load neon_local config, skipping compute update ({e})");
return Ok(());
}
};
let cplane =
ComputeControlPlane::load(env.clone()).expect("Error loading compute control plane");
let ComputeHookNotifyRequest {
let NotifyAttachRequest {
tenant_id,
shards,
stripe_size,
preferred_az: _preferred_az,
} = reconfigure_request;
} = req;
for (endpoint_name, endpoint) in &cplane.endpoints {
if endpoint.tenant_id == *tenant_id && endpoint.status() == EndpointStatus::Running {
tracing::info!("Reconfiguring endpoint {endpoint_name}");
tracing::info!("Reconfiguring pageservers for endpoint {endpoint_name}");
let pageservers = shards
.iter()
@@ -445,7 +526,7 @@ impl ComputeHook {
.collect::<Vec<_>>();
endpoint
.reconfigure(pageservers, *stripe_size, None)
.reconfigure_pageservers(pageservers, *stripe_size)
.await
.map_err(NotifyError::NeonLocal)?;
}
@@ -453,11 +534,102 @@ impl ComputeHook {
Ok(())
}
}
async fn do_notify_iteration(
/// The compute hook is a destination for notifications about changes to tenant:pageserver
/// mapping. It aggregates updates for the shards in a tenant, and when appropriate reconfigures
/// the compute connection string.
pub(super) struct ComputeHook {
config: Config,
tenants: std::sync::Mutex<HashMap<TenantId, ComputeHookTenant>>,
timelines: std::sync::Mutex<HashMap<TenantTimelineId, ComputeHookTimeline>>,
authorization_header: Option<String>,
// Concurrency limiter, so that we do not overload the cloud control plane when updating
// large numbers of tenants (e.g. when failing over after a node failure)
api_concurrency: tokio::sync::Semaphore,
// This lock is only used in testing enviroments, to serialize calls into neon_local
neon_local_lock: tokio::sync::Mutex<()>,
// We share a client across all notifications to enable connection re-use etc when
// sending large numbers of notifications
client: reqwest::Client,
}
/// Callers may give us a list of these when asking us to send a bulk batch
/// of notifications in the background. This is a 'notification' in the sense of
/// other code notifying us of a shard's status, rather than being the final notification
/// that we send upwards to the control plane for the whole tenant.
pub(crate) struct ShardUpdate<'a> {
pub(crate) tenant_shard_id: TenantShardId,
pub(crate) node_id: NodeId,
pub(crate) stripe_size: ShardStripeSize,
pub(crate) preferred_az: Option<Cow<'a, AvailabilityZone>>,
}
pub(crate) struct SafekeepersUpdate {
pub(crate) tenant_id: TenantId,
pub(crate) timeline_id: TimelineId,
pub(crate) generation: SafekeeperGeneration,
pub(crate) safekeepers: Vec<SafekeeperInfo>,
}
impl ComputeHook {
pub(super) fn new(config: Config) -> anyhow::Result<Self> {
let authorization_header = config
.control_plane_jwt_token
.clone()
.map(|jwt| format!("Bearer {jwt}"));
let mut client = reqwest::ClientBuilder::new().timeout(NOTIFY_REQUEST_TIMEOUT);
for cert in &config.ssl_ca_certs {
client = client.add_root_certificate(cert.clone());
}
let client = client
.build()
.context("Failed to build http client for compute hook")?;
Ok(Self {
tenants: Default::default(),
timelines: Default::default(),
config,
authorization_header,
neon_local_lock: Default::default(),
api_concurrency: tokio::sync::Semaphore::new(API_CONCURRENCY),
client,
})
}
/// For test environments: use neon_local's LocalEnv to update compute
async fn do_notify_local<M: ApiMethod>(&self, req: &M::Request) -> Result<(), NotifyError> {
// neon_local updates are not safe to call concurrently, use a lock to serialize
// all calls to this function
let _locked = self.neon_local_lock.lock().await;
let Some(repo_dir) = self.config.neon_local_repo_dir.as_deref() else {
tracing::warn!(
"neon_local_repo_dir not set, likely a bug in neon_local; skipping compute update"
);
return Ok(());
};
let env = match LocalEnv::load_config(repo_dir) {
Ok(e) => e,
Err(e) => {
tracing::warn!("Couldn't load neon_local config, skipping compute update ({e})");
return Ok(());
}
};
let cplane =
ComputeControlPlane::load(env.clone()).expect("Error loading compute control plane");
M::notify_local(&env, &cplane, req).await
}
async fn do_notify_iteration<Req: serde::Serialize + std::fmt::Debug>(
&self,
url: &String,
reconfigure_request: &ComputeHookNotifyRequest,
reconfigure_request: &Req,
cancel: &CancellationToken,
) -> Result<(), NotifyError> {
let req = self.client.request(reqwest::Method::PUT, url);
@@ -479,9 +651,7 @@ impl ComputeHook {
};
// Treat all 2xx responses as success
if response.status() >= reqwest::StatusCode::OK
&& response.status() < reqwest::StatusCode::MULTIPLE_CHOICES
{
if response.status().is_success() {
if response.status() != reqwest::StatusCode::OK {
// Non-200 2xx response: it doesn't make sense to retry, but this is unexpected, so
// log a warning.
@@ -532,10 +702,10 @@ impl ComputeHook {
}
}
async fn do_notify(
async fn do_notify<R: serde::Serialize + std::fmt::Debug>(
&self,
url: &String,
reconfigure_request: &ComputeHookNotifyRequest,
reconfigure_request: &R,
cancel: &CancellationToken,
) -> Result<(), NotifyError> {
// We hold these semaphore units across all retries, rather than only across each
@@ -567,13 +737,13 @@ impl ComputeHook {
}
/// Synchronous phase: update the per-tenant state for the next intended notification
fn notify_prepare(&self, shard_update: ShardUpdate) -> MaybeSendResult {
let mut state_locked = self.state.lock().unwrap();
fn notify_attach_prepare(&self, shard_update: ShardUpdate) -> MaybeSendNotifyAttachResult {
let mut tenants_locked = self.tenants.lock().unwrap();
use std::collections::hash_map::Entry;
let tenant_shard_id = shard_update.tenant_shard_id;
let tenant = match state_locked.entry(tenant_shard_id.tenant_id) {
let tenant = match tenants_locked.entry(tenant_shard_id.tenant_id) {
Entry::Vacant(e) => {
let ShardUpdate {
tenant_shard_id,
@@ -597,10 +767,37 @@ impl ComputeHook {
tenant.maybe_send(tenant_shard_id.tenant_id, None)
}
async fn notify_execute(
fn notify_safekeepers_prepare(
&self,
maybe_send_result: MaybeSendResult,
tenant_shard_id: TenantShardId,
safekeepers_update: SafekeepersUpdate,
) -> MaybeSendNotifySafekeepersResult {
let mut timelines_locked = self.timelines.lock().unwrap();
let ttid = TenantTimelineId {
tenant_id: safekeepers_update.tenant_id,
timeline_id: safekeepers_update.timeline_id,
};
use std::collections::hash_map::Entry;
let timeline = match timelines_locked.entry(ttid) {
Entry::Vacant(e) => e.insert(ComputeHookTimeline::new(
safekeepers_update.generation,
safekeepers_update.safekeepers,
)),
Entry::Occupied(e) => {
let timeline = e.into_mut();
timeline.update(safekeepers_update);
timeline
}
};
timeline.maybe_send(ttid, None)
}
async fn notify_execute<M: ApiMethod>(
&self,
state: &std::sync::Mutex<HashMap<M::Key, M>>,
maybe_send_result: MaybeSendResult<M::Request, M::Key>,
cancel: &CancellationToken,
) -> Result<(), NotifyError> {
// Process result: we may get an update to send, or we may have to wait for a lock
@@ -609,7 +806,7 @@ impl ComputeHook {
MaybeSendResult::Noop => {
return Ok(());
}
MaybeSendResult::AwaitLock(send_lock) => {
MaybeSendResult::AwaitLock((key, send_lock)) => {
let send_locked = tokio::select! {
guard = send_lock.lock_owned() => {guard},
_ = cancel.cancelled() => {
@@ -620,11 +817,11 @@ impl ComputeHook {
// Lock order: maybe_send is called within the `[Self::state]` lock, and takes the send lock, but here
// we have acquired the send lock and take `[Self::state]` lock. This is safe because maybe_send only uses
// try_lock.
let state_locked = self.state.lock().unwrap();
let Some(tenant) = state_locked.get(&tenant_shard_id.tenant_id) else {
let state_locked = state.lock().unwrap();
let Some(resource_state) = state_locked.get(&key) else {
return Ok(());
};
match tenant.maybe_send(tenant_shard_id.tenant_id, Some(send_locked)) {
match resource_state.maybe_send(key, Some(send_locked)) {
MaybeSendResult::AwaitLock(_) => {
unreachable!("We supplied lock guard")
}
@@ -643,14 +840,18 @@ impl ComputeHook {
.control_plane_url
.as_ref()
.map(|control_plane_url| {
format!("{}/notify-attach", control_plane_url.trim_end_matches('/'))
format!(
"{}/{}",
control_plane_url.trim_end_matches('/'),
M::API_PATH
)
});
// We validate this at startup
let notify_url = compute_hook_url.as_ref().unwrap();
self.do_notify(notify_url, &request, cancel).await
} else {
self.do_notify_local(&request).await.map_err(|e| {
self.do_notify_local::<M>(&request).await.map_err(|e| {
// This path is for testing only, so munge the error into our prod-style error type.
tracing::error!("neon_local notification hook failed: {e}");
NotifyError::Fatal(StatusCode::INTERNAL_SERVER_ERROR)
@@ -686,7 +887,7 @@ impl ComputeHook {
/// Infallible synchronous fire-and-forget version of notify(), that sends its results to
/// a channel. Something should consume the channel and arrange to try notifying again
/// if something failed.
pub(super) fn notify_background(
pub(super) fn notify_attach_background(
self: &Arc<Self>,
notifications: Vec<ShardUpdate>,
result_tx: tokio::sync::mpsc::Sender<Result<(), (TenantShardId, NotifyError)>>,
@@ -695,7 +896,7 @@ impl ComputeHook {
let mut maybe_sends = Vec::new();
for shard_update in notifications {
let tenant_shard_id = shard_update.tenant_shard_id;
let maybe_send_result = self.notify_prepare(shard_update);
let maybe_send_result = self.notify_attach_prepare(shard_update);
maybe_sends.push((tenant_shard_id, maybe_send_result))
}
@@ -714,10 +915,10 @@ impl ComputeHook {
async move {
this
.notify_execute(maybe_send_result, tenant_shard_id, &cancel)
.notify_execute(&this.tenants, maybe_send_result, &cancel)
.await.map_err(|e| (tenant_shard_id, e))
}.instrument(info_span!(
"notify_background", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()
"notify_attach_background", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()
))
})
.buffered(API_CONCURRENCY);
@@ -760,14 +961,23 @@ impl ComputeHook {
/// ensuring that they eventually call again to ensure that the compute is eventually notified of
/// the proper pageserver nodes for a tenant.
#[tracing::instrument(skip_all, fields(tenant_id=%shard_update.tenant_shard_id.tenant_id, shard_id=%shard_update.tenant_shard_id.shard_slug(), node_id))]
pub(super) async fn notify<'a>(
pub(super) async fn notify_attach<'a>(
&self,
shard_update: ShardUpdate<'a>,
cancel: &CancellationToken,
) -> Result<(), NotifyError> {
let tenant_shard_id = shard_update.tenant_shard_id;
let maybe_send_result = self.notify_prepare(shard_update);
self.notify_execute(maybe_send_result, tenant_shard_id, cancel)
let maybe_send_result = self.notify_attach_prepare(shard_update);
self.notify_execute(&self.tenants, maybe_send_result, cancel)
.await
}
pub(super) async fn notify_safekeepers(
&self,
safekeepers_update: SafekeepersUpdate,
cancel: &CancellationToken,
) -> Result<(), NotifyError> {
let maybe_send_result = self.notify_safekeepers_prepare(safekeepers_update);
self.notify_execute(&self.timelines, maybe_send_result, cancel)
.await
}
@@ -783,8 +993,8 @@ impl ComputeHook {
) {
use std::collections::hash_map::Entry;
let mut state_locked = self.state.lock().unwrap();
match state_locked.entry(tenant_shard_id.tenant_id) {
let mut tenants_locked = self.tenants.lock().unwrap();
match tenants_locked.entry(tenant_shard_id.tenant_id) {
Entry::Vacant(_) => {
// This is a valid but niche case, where the tenant was previously attached
// as a Secondary location and then detached, so has no previously notified

View File

@@ -22,7 +22,7 @@ use pageserver_api::controller_api::{
MetadataHealthListUnhealthyResponse, MetadataHealthUpdateRequest, MetadataHealthUpdateResponse,
NodeAvailability, NodeConfigureRequest, NodeRegisterRequest, SafekeeperSchedulingPolicyRequest,
ShardsPreferredAzsRequest, TenantCreateRequest, TenantPolicyRequest, TenantShardMigrateRequest,
TimelineImportRequest,
TimelineImportRequest, TimelineSafekeeperMigrateRequest,
};
use pageserver_api::models::{
DetachBehavior, LsnLeaseRequest, TenantConfigPatchRequest, TenantConfigRequest,
@@ -34,6 +34,7 @@ use pageserver_api::upcall_api::{
PutTimelineImportStatusRequest, ReAttachRequest, TimelineImportStatusRequest, ValidateRequest,
};
use pageserver_client::{BlockUnblock, mgmt_api};
use routerify::Middleware;
use tokio_util::sync::CancellationToken;
use tracing::warn;
@@ -635,6 +636,32 @@ async fn handle_tenant_timeline_download_heatmap_layers(
json_response(StatusCode::OK, ())
}
async fn handle_tenant_timeline_safekeeper_migrate(
service: Arc<Service>,
req: Request<Body>,
) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
check_permissions(&req, Scope::PageServerApi)?;
maybe_rate_limit(&req, tenant_id).await;
let timeline_id: TimelineId = parse_request_param(&req, "timeline_id")?;
let mut req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
return res;
}
ForwardOutcome::NotForwarded(req) => req,
};
let migrate_req = json_request::<TimelineSafekeeperMigrateRequest>(&mut req).await?;
service
.tenant_timeline_safekeeper_migrate(tenant_id, timeline_id, migrate_req)
.await?;
json_response(StatusCode::OK, ())
}
async fn handle_tenant_timeline_lsn_lease(
service: Arc<Service>,
req: Request<Body>,
@@ -2458,6 +2485,16 @@ pub fn make_router(
)
},
)
.post(
"/v1/tenant/:tenant_id/timeline/:timeline_id/safekeeper_migrate",
|r| {
tenant_service_handler(
r,
handle_tenant_timeline_safekeeper_migrate,
RequestName("v1_tenant_timeline_safekeeper_migrate"),
)
},
)
// LSN lease passthrough to all shards
.post(
"/v1/tenant/:tenant_id/timeline/:timeline_id/lsn_lease",

View File

@@ -6,9 +6,7 @@ use std::time::Duration;
use anyhow::{Context, anyhow};
use camino::Utf8PathBuf;
#[cfg(feature = "testing")]
use clap::ArgAction;
use clap::Parser;
use clap::{ArgAction, Parser};
use futures::future::OptionFuture;
use http_utils::tls_certs::ReloadingCertificateResolver;
use hyper0::Uri;
@@ -216,14 +214,13 @@ struct Cli {
/// Number of safekeepers to choose for a timeline when creating it.
/// Safekeepers will be choosen from different availability zones.
/// This option exists primarily for testing purposes.
#[arg(long, default_value = "3", value_parser = clap::value_parser!(i64).range(1..))]
timeline_safekeeper_count: i64,
#[arg(long, default_value = "3", value_parser = clap::builder::RangedU64ValueParser::<usize>::new().range(1..))]
timeline_safekeeper_count: usize,
/// When set, actively checks and initiates heatmap downloads/uploads during reconciliation.
/// This speed up migrations by avoiding the default wait for the heatmap download interval.
/// Primarily useful for testing to reduce test execution time.
#[cfg(feature = "testing")]
#[arg(long, default_value = "true", action=ArgAction::Set)]
#[arg(long, default_value = "false", action=ArgAction::Set)]
kick_secondary_downloads: bool,
}
@@ -472,7 +469,6 @@ async fn async_main() -> anyhow::Result<()> {
use_local_compute_notifications: args.use_local_compute_notifications,
timeline_safekeeper_count: args.timeline_safekeeper_count,
posthog_config: posthog_config.clone(),
#[cfg(feature = "testing")]
kick_secondary_downloads: args.kick_secondary_downloads,
};
@@ -560,9 +556,15 @@ async fn async_main() -> anyhow::Result<()> {
let cancel_bg = cancel.clone();
let task = tokio::task::spawn(
async move {
let feature_flag_service = FeatureFlagService::new(service, posthog_config);
let feature_flag_service = Arc::new(feature_flag_service);
feature_flag_service.run(cancel_bg).await
match FeatureFlagService::new(service, posthog_config) {
Ok(feature_flag_service) => {
let feature_flag_service = Arc::new(feature_flag_service);
feature_flag_service.run(cancel_bg).await
}
Err(e) => {
tracing::warn!("Failed to create feature flag service: {}", e);
}
};
}
.instrument(tracing::info_span!("feature_flag_service")),
);

View File

@@ -333,6 +333,7 @@ pub(crate) enum DatabaseErrorLabel {
ConnectionPool,
Logical,
Migration,
Cas,
}
impl DatabaseError {
@@ -343,6 +344,7 @@ impl DatabaseError {
Self::ConnectionPool(_) => DatabaseErrorLabel::ConnectionPool,
Self::Logical(_) => DatabaseErrorLabel::Logical,
Self::Migration(_) => DatabaseErrorLabel::Migration,
Self::Cas(_) => DatabaseErrorLabel::Cas,
}
}
}

View File

@@ -29,6 +29,7 @@ use pageserver_api::shard::{
use rustls::client::WebPkiServerVerifier;
use rustls::client::danger::{ServerCertVerified, ServerCertVerifier};
use rustls::crypto::ring;
use safekeeper_api::membership::SafekeeperGeneration;
use scoped_futures::ScopedBoxFuture;
use serde::{Deserialize, Serialize};
use utils::generation::Generation;
@@ -94,6 +95,8 @@ pub(crate) enum DatabaseError {
Logical(String),
#[error("Migration error: {0}")]
Migration(String),
#[error("CAS error: {0}")]
Cas(String),
}
#[derive(measured::FixedCardinalityLabel, Copy, Clone)]
@@ -126,6 +129,7 @@ pub(crate) enum DatabaseOperation {
UpdateLeader,
SetPreferredAzs,
InsertTimeline,
UpdateTimelineMembership,
GetTimeline,
InsertTimelineReconcile,
RemoveTimelineReconcile,
@@ -1410,6 +1414,56 @@ impl Persistence {
.await
}
/// Update timeline membership configuration in the database.
/// Perform a compare-and-swap (CAS) operation on the timeline's generation.
/// The `new_generation` must be the next (+1) generation after the one in the database.
pub(crate) async fn update_timeline_membership(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
new_generation: SafekeeperGeneration,
sk_set: &[NodeId],
new_sk_set: Option<&[NodeId]>,
) -> DatabaseResult<()> {
use crate::schema::timelines::dsl;
let prev_generation = new_generation.previous().unwrap();
let tenant_id = &tenant_id;
let timeline_id = &timeline_id;
self.with_measured_conn(DatabaseOperation::UpdateTimelineMembership, move |conn| {
Box::pin(async move {
let updated = diesel::update(dsl::timelines)
.filter(dsl::tenant_id.eq(&tenant_id.to_string()))
.filter(dsl::timeline_id.eq(&timeline_id.to_string()))
.filter(dsl::generation.eq(prev_generation.into_inner() as i32))
.set((
dsl::generation.eq(new_generation.into_inner() as i32),
dsl::sk_set.eq(sk_set.iter().map(|id| id.0 as i64).collect::<Vec<_>>()),
dsl::new_sk_set.eq(new_sk_set
.map(|set| set.iter().map(|id| id.0 as i64).collect::<Vec<_>>())),
))
.execute(conn)
.await?;
match updated {
0 => {
// TODO(diko): It makes sense to select the current generation
// and include it in the error message for better debuggability.
Err(DatabaseError::Cas(
"Failed to update membership configuration".to_string(),
))
}
1 => Ok(()),
_ => Err(DatabaseError::Logical(format!(
"unexpected number of rows ({updated})"
))),
}
})
})
.await
}
/// Load timeline from db. Returns `None` if not present.
pub(crate) async fn get_timeline(
&self,

View File

@@ -65,7 +65,7 @@ pub(super) struct Reconciler {
pub(crate) compute_hook: Arc<ComputeHook>,
/// To avoid stalling if the cloud control plane is unavailable, we may proceed
/// past failures in [`ComputeHook::notify`], but we _must_ remember that we failed
/// past failures in [`ComputeHook::notify_attach`], but we _must_ remember that we failed
/// so that we can set [`crate::tenant_shard::TenantShard::pending_compute_notification`] to ensure a later retry.
pub(crate) compute_notify_failure: bool,
@@ -1023,7 +1023,7 @@ impl Reconciler {
if let Some(node) = &self.intent.attached {
let result = self
.compute_hook
.notify(
.notify_attach(
compute_hook::ShardUpdate {
tenant_shard_id: self.tenant_shard_id,
node_id: node.get_id(),

View File

@@ -2,6 +2,7 @@ use std::time::Duration;
use pageserver_api::controller_api::{SafekeeperDescribeResponse, SkSchedulingPolicy};
use reqwest::StatusCode;
use safekeeper_api::membership::SafekeeperId;
use safekeeper_client::mgmt_api;
use tokio_util::sync::CancellationToken;
use utils::backoff;
@@ -92,6 +93,13 @@ impl Safekeeper {
pub(crate) fn has_https_port(&self) -> bool {
self.listen_https_port.is_some()
}
pub(crate) fn get_safekeeper_id(&self) -> SafekeeperId {
SafekeeperId {
id: self.id,
host: self.skp.host.clone(),
pg_port: self.skp.port as u16,
}
}
/// Perform an operation (which is given a [`SafekeeperClient`]) with retries
#[allow(clippy::too_many_arguments)]
pub(crate) async fn with_client_retries<T, O, F>(

View File

@@ -56,6 +56,10 @@ impl SafekeeperClient {
}
}
pub(crate) fn node_id_label(&self) -> &str {
&self.node_id_label
}
pub(crate) async fn create_timeline(
&self,
req: &TimelineCreateRequest,

View File

@@ -161,6 +161,7 @@ enum TenantOperations {
DropDetached,
DownloadHeatmapLayers,
TimelineLsnLease,
TimelineSafekeeperMigrate,
}
#[derive(Clone, strum_macros::Display)]
@@ -471,12 +472,12 @@ pub struct Config {
/// Number of safekeepers to choose for a timeline when creating it.
/// Safekeepers will be choosen from different availability zones.
pub timeline_safekeeper_count: i64,
pub timeline_safekeeper_count: usize,
/// PostHog integration config
pub posthog_config: Option<PostHogConfig>,
#[cfg(feature = "testing")]
/// When set, actively checks and initiates heatmap downloads/uploads.
pub kick_secondary_downloads: bool,
}
@@ -491,6 +492,7 @@ impl From<DatabaseError> for ApiError {
DatabaseError::Logical(reason) | DatabaseError::Migration(reason) => {
ApiError::InternalServerError(anyhow::anyhow!(reason))
}
DatabaseError::Cas(reason) => ApiError::Conflict(reason),
}
}
}
@@ -876,18 +878,18 @@ impl Service {
// Emit compute hook notifications for all tenants which are already stably attached. Other tenants
// will emit compute hook notifications when they reconcile.
//
// Ordering: our calls to notify_background synchronously establish a relative order for these notifications vs. any later
// Ordering: our calls to notify_attach_background synchronously establish a relative order for these notifications vs. any later
// calls into the ComputeHook for the same tenant: we can leave these to run to completion in the background and any later
// calls will be correctly ordered wrt these.
//
// Concurrency: we call notify_background for all tenants, which will create O(N) tokio tasks, but almost all of them
// Concurrency: we call notify_attach_background for all tenants, which will create O(N) tokio tasks, but almost all of them
// will just wait on the ComputeHook::API_CONCURRENCY semaphore immediately, so very cheap until they get that semaphore
// unit and start doing I/O.
tracing::info!(
"Sending {} compute notifications",
compute_notifications.len()
);
self.compute_hook.notify_background(
self.compute_hook.notify_attach_background(
compute_notifications,
bg_compute_notify_result_tx.clone(),
&self.cancel,
@@ -2582,7 +2584,7 @@ impl Service {
.do_initial_shard_scheduling(
tenant_shard_id,
initial_generation,
&create_req.shard_parameters,
create_req.shard_parameters,
create_req.config.clone(),
placement_policy.clone(),
preferred_az_id.as_ref(),
@@ -2639,7 +2641,7 @@ impl Service {
&self,
tenant_shard_id: TenantShardId,
initial_generation: Option<Generation>,
shard_params: &ShardParameters,
shard_params: ShardParameters,
config: TenantConfig,
placement_policy: PlacementPolicy,
preferred_az_id: Option<&AvailabilityZone>,
@@ -6279,7 +6281,7 @@ impl Service {
for (child_id, child_ps, stripe_size) in child_locations {
if let Err(e) = self
.compute_hook
.notify(
.notify_attach(
compute_hook::ShardUpdate {
tenant_shard_id: child_id,
node_id: child_ps,
@@ -8364,7 +8366,6 @@ impl Service {
"Skipping migration of {tenant_shard_id} to {node} because secondary isn't ready: {progress:?}"
);
#[cfg(feature = "testing")]
if progress.heatmap_mtime.is_none() {
// No heatmap might mean the attached location has never uploaded one, or that
// the secondary download hasn't happened yet. This is relatively unusual in the field,
@@ -8389,7 +8390,6 @@ impl Service {
/// happens on multi-minute timescales in the field, which is fine because optimisation is meant
/// to be a lazy background thing. However, when testing, it is not practical to wait around, so
/// we have this helper to move things along faster.
#[cfg(feature = "testing")]
async fn kick_secondary_download(&self, tenant_shard_id: TenantShardId) {
if !self.config.kick_secondary_downloads {
// No-op if kick_secondary_downloads functionaliuty is not configured

View File

@@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
use futures::StreamExt;
use pageserver_api::config::PostHogConfig;
use pageserver_client::mgmt_api;
use posthog_client_lite::{PostHogClient, PostHogClientConfig};
use posthog_client_lite::PostHogClient;
use reqwest::StatusCode;
use tokio::time::MissedTickBehavior;
use tokio_util::sync::CancellationToken;
@@ -20,20 +20,14 @@ pub struct FeatureFlagService {
const DEFAULT_POSTHOG_REFRESH_INTERVAL: Duration = Duration::from_secs(30);
impl FeatureFlagService {
pub fn new(service: Arc<Service>, config: PostHogConfig) -> Self {
let client = PostHogClient::new(PostHogClientConfig {
project_id: config.project_id.clone(),
server_api_key: config.server_api_key.clone(),
client_api_key: config.client_api_key.clone(),
private_api_url: config.private_api_url.clone(),
public_api_url: config.public_api_url.clone(),
});
Self {
pub fn new(service: Arc<Service>, config: PostHogConfig) -> Result<Self, &'static str> {
let client = PostHogClient::new(config.clone().try_into_posthog_config()?);
Ok(Self {
service,
config,
client,
http_client: reqwest::Client::new(),
}
})
}
async fn refresh(self: Arc<Self>, cancel: CancellationToken) -> Result<(), anyhow::Error> {

View File

@@ -145,7 +145,7 @@ pub(crate) async fn load_schedule_requests(
}
let Some(sk) = safekeepers.get(&other_node_id) else {
tracing::warn!(
"couldnt find safekeeper with pending op id {other_node_id}, not pulling from it"
"couldn't find safekeeper with pending op id {other_node_id}, not pulling from it"
);
return None;
};

View File

@@ -1,25 +1,34 @@
use std::cmp::max;
use std::collections::HashSet;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use super::safekeeper_reconciler::ScheduleRequest;
use crate::compute_hook;
use crate::heartbeater::SafekeeperState;
use crate::id_lock_map::trace_shared_lock;
use crate::metrics;
use crate::persistence::{
DatabaseError, SafekeeperTimelineOpKind, TimelinePendingOpPersistence, TimelinePersistence,
};
use crate::safekeeper::Safekeeper;
use crate::safekeeper_client::SafekeeperClient;
use crate::service::TenantOperations;
use crate::timeline_import::TimelineImportFinalizeError;
use anyhow::Context;
use http_utils::error::ApiError;
use pageserver_api::controller_api::{
SafekeeperDescribeResponse, SkSchedulingPolicy, TimelineImportRequest,
TimelineSafekeeperMigrateRequest,
};
use pageserver_api::models::{SafekeeperInfo, SafekeepersInfo, TimelineInfo};
use safekeeper_api::PgVersionId;
use safekeeper_api::membership::{MemberSet, SafekeeperGeneration, SafekeeperId};
use safekeeper_api::membership::{self, MemberSet, SafekeeperGeneration};
use safekeeper_api::models::{
PullTimelineRequest, TimelineMembershipSwitchRequest, TimelineMembershipSwitchResponse,
};
use safekeeper_api::{INITIAL_TERM, Term};
use safekeeper_client::mgmt_api;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use utils::id::{NodeId, TenantId, TimelineId};
@@ -36,6 +45,33 @@ pub struct TimelineLocateResponse {
}
impl Service {
fn make_member_set(safekeepers: &[Safekeeper]) -> Result<MemberSet, ApiError> {
let members = safekeepers
.iter()
.map(|sk| sk.get_safekeeper_id())
.collect::<Vec<_>>();
MemberSet::new(members).map_err(ApiError::InternalServerError)
}
fn get_safekeepers(&self, ids: &[i64]) -> Result<Vec<Safekeeper>, ApiError> {
let safekeepers = {
let locked = self.inner.read().unwrap();
locked.safekeepers.clone()
};
ids.iter()
.map(|&id| {
let node_id = NodeId(id as u64);
safekeepers.get(&node_id).cloned().ok_or_else(|| {
ApiError::InternalServerError(anyhow::anyhow!(
"safekeeper {node_id} is not registered"
))
})
})
.collect::<Result<Vec<_>, _>>()
}
/// Timeline creation on safekeepers
///
/// Returns `Ok(left)` if the timeline has been created on a quorum of safekeepers,
@@ -48,35 +84,9 @@ impl Service {
pg_version: PgVersionId,
timeline_persistence: &TimelinePersistence,
) -> Result<Vec<NodeId>, ApiError> {
// If quorum is reached, return if we are outside of a specified timeout
let jwt = self
.config
.safekeeper_jwt_token
.clone()
.map(SecretString::from);
let mut joinset = JoinSet::new();
let safekeepers = self.get_safekeepers(&timeline_persistence.sk_set)?;
// Prepare membership::Configuration from choosen safekeepers.
let safekeepers = {
let locked = self.inner.read().unwrap();
locked.safekeepers.clone()
};
let mut members = Vec::new();
for sk_id in timeline_persistence.sk_set.iter() {
let sk_id = NodeId(*sk_id as u64);
let Some(safekeeper) = safekeepers.get(&sk_id) else {
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"couldn't find entry for safekeeper with id {sk_id}"
)))?;
};
members.push(SafekeeperId {
id: sk_id,
host: safekeeper.skp.host.clone(),
pg_port: safekeeper.skp.port as u16,
});
}
let mset = MemberSet::new(members).map_err(ApiError::InternalServerError)?;
let mset = Self::make_member_set(&safekeepers)?;
let mconf = safekeeper_api::membership::Configuration::new(mset);
let req = safekeeper_api::models::TimelineCreateRequest {
@@ -89,79 +99,150 @@ impl Service {
timeline_id,
wal_seg_size: None,
};
const SK_CREATE_TIMELINE_RECONCILE_TIMEOUT: Duration = Duration::from_secs(30);
for sk in timeline_persistence.sk_set.iter() {
let sk_id = NodeId(*sk as u64);
let safekeepers = safekeepers.clone();
let results = self
.tenant_timeline_safekeeper_op_quorum(
&safekeepers,
move |client| {
let req = req.clone();
async move { client.create_timeline(&req).await }
},
SK_CREATE_TIMELINE_RECONCILE_TIMEOUT,
)
.await?;
Ok(results
.into_iter()
.enumerate()
.filter_map(|(idx, res)| {
if res.is_ok() {
None // Success, don't return this safekeeper
} else {
Some(safekeepers[idx].get_id()) // Failure, return this safekeeper
}
})
.collect::<Vec<_>>())
}
/// Perform an operation on a list of safekeepers in parallel with retries.
///
/// Return the results of the operation on each safekeeper in the input order.
async fn tenant_timeline_safekeeper_op<T, O, F>(
&self,
safekeepers: &[Safekeeper],
op: O,
timeout: Duration,
) -> Result<Vec<mgmt_api::Result<T>>, ApiError>
where
O: FnMut(SafekeeperClient) -> F + Send + 'static,
O: Clone,
F: std::future::Future<Output = mgmt_api::Result<T>> + Send + 'static,
T: Sync + Send + 'static,
{
let jwt = self
.config
.safekeeper_jwt_token
.clone()
.map(SecretString::from);
let mut joinset = JoinSet::new();
for (idx, sk) in safekeepers.iter().enumerate() {
let sk = sk.clone();
let http_client = self.http_client.clone();
let jwt = jwt.clone();
let req = req.clone();
let op = op.clone();
joinset.spawn(async move {
// Unwrap is fine as we already would have returned error above
let sk_p = safekeepers.get(&sk_id).unwrap();
let res = sk_p
let res = sk
.with_client_retries(
|client| {
let req = req.clone();
async move { client.create_timeline(&req).await }
},
op,
&http_client,
&jwt,
3,
3,
SK_CREATE_TIMELINE_RECONCILE_TIMEOUT,
// TODO(diko): This is a wrong timeout.
// It should be scaled to the retry count.
timeout,
&CancellationToken::new(),
)
.await;
(sk_id, sk_p.skp.host.clone(), res)
(idx, res)
});
}
// Initialize results with timeout errors in case we never get a response.
let mut results: Vec<mgmt_api::Result<T>> = safekeepers
.iter()
.map(|_| {
Err(mgmt_api::Error::Timeout(
"safekeeper operation timed out".to_string(),
))
})
.collect();
// After we have built the joinset, we now wait for the tasks to complete,
// but with a specified timeout to make sure we return swiftly, either with
// a failure or success.
let reconcile_deadline = tokio::time::Instant::now() + SK_CREATE_TIMELINE_RECONCILE_TIMEOUT;
let reconcile_deadline = tokio::time::Instant::now() + timeout;
// Wait until all tasks finish or timeout is hit, whichever occurs
// first.
let mut reconcile_results = Vec::new();
let mut result_count = 0;
loop {
if let Ok(res) = tokio::time::timeout_at(reconcile_deadline, joinset.join_next()).await
{
let Some(res) = res else { break };
match res {
Ok(res) => {
Ok((idx, res)) => {
let sk = &safekeepers[idx];
tracing::info!(
"response from safekeeper id:{} at {}: {:?}",
res.0,
res.1,
res.2
sk.get_id(),
sk.skp.host,
// Only print errors, as there is no Debug trait for T.
res.as_ref().map(|_| ()),
);
reconcile_results.push(res);
results[idx] = res;
result_count += 1;
}
Err(join_err) => {
tracing::info!("join_err for task in joinset: {join_err}");
}
}
} else {
tracing::info!(
"timeout for creation call after {} responses",
reconcile_results.len()
);
tracing::info!("timeout for operation call after {result_count} responses",);
break;
}
}
// Now check now if quorum was reached in reconcile_results.
let total_result_count = reconcile_results.len();
let remaining = reconcile_results
.into_iter()
.filter_map(|res| res.2.is_err().then_some(res.0))
.collect::<Vec<_>>();
tracing::info!(
"Got {} non-successful responses from initial creation request of total {total_result_count} responses",
remaining.len()
);
let target_sk_count = timeline_persistence.sk_set.len();
Ok(results)
}
/// Perform an operation on a list of safekeepers in parallel with retries,
/// and validates that we reach a quorum of successful responses.
///
/// Return the results of the operation on each safekeeper in the input order.
/// It's guaranteed that at least a quorum of the responses are successful.
async fn tenant_timeline_safekeeper_op_quorum<T, O, F>(
&self,
safekeepers: &[Safekeeper],
op: O,
timeout: Duration,
) -> Result<Vec<mgmt_api::Result<T>>, ApiError>
where
O: FnMut(SafekeeperClient) -> F,
O: Clone + Send + 'static,
F: std::future::Future<Output = mgmt_api::Result<T>> + Send + 'static,
T: Sync + Send + 'static,
{
let results = self
.tenant_timeline_safekeeper_op(safekeepers, op, timeout)
.await?;
// Now check if quorum was reached in results.
let target_sk_count = safekeepers.len();
let quorum_size = match target_sk_count {
0 => {
return Err(ApiError::InternalServerError(anyhow::anyhow!(
@@ -180,7 +261,7 @@ impl Service {
// in order to schedule work to them
tracing::warn!(
"couldn't find at least 3 safekeepers for timeline, found: {:?}",
timeline_persistence.sk_set
target_sk_count
);
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"couldn't find at least 3 safekeepers to put timeline to"
@@ -189,7 +270,7 @@ impl Service {
}
_ => target_sk_count / 2 + 1,
};
let success_count = target_sk_count - remaining.len();
let success_count = results.iter().filter(|res| res.is_ok()).count();
if success_count < quorum_size {
// Failure
return Err(ApiError::InternalServerError(anyhow::anyhow!(
@@ -197,7 +278,7 @@ impl Service {
)));
}
Ok(remaining)
Ok(results)
}
/// Create timeline in controller database and on safekeepers.
@@ -654,13 +735,7 @@ impl Service {
)
});
// Number of safekeepers in different AZs we are looking for
let mut wanted_count = self.config.timeline_safekeeper_count as usize;
// TODO(diko): remove this when `timeline_safekeeper_count` option is in the release
// branch and is specified in tests/neon_local config.
if cfg!(feature = "testing") && all_safekeepers.len() < wanted_count {
// In testing mode, we can have less safekeepers than the config says
wanted_count = max(all_safekeepers.len(), 1);
}
let wanted_count = self.config.timeline_safekeeper_count;
let mut sks = Vec::new();
let mut azs = HashSet::new();
@@ -804,4 +879,486 @@ impl Service {
}
Ok(())
}
/// Call `switch_timeline_membership` on all safekeepers with retries
/// till the quorum of successful responses is reached.
///
/// If min_position is not None, validates that majority of safekeepers
/// reached at least min_position.
///
/// Return responses from safekeepers in the input order.
async fn tenant_timeline_set_membership_quorum(
self: &Arc<Self>,
tenant_id: TenantId,
timeline_id: TimelineId,
safekeepers: &[Safekeeper],
config: &membership::Configuration,
min_position: Option<(Term, Lsn)>,
) -> Result<Vec<mgmt_api::Result<TimelineMembershipSwitchResponse>>, ApiError> {
let req = TimelineMembershipSwitchRequest {
mconf: config.clone(),
};
const SK_SET_MEM_TIMELINE_RECONCILE_TIMEOUT: Duration = Duration::from_secs(30);
let results = self
.tenant_timeline_safekeeper_op_quorum(
safekeepers,
move |client| {
let req = req.clone();
async move {
let mut res = client
.switch_timeline_membership(tenant_id, timeline_id, &req)
.await;
// If min_position is not reached, map the response to an error,
// so it isn't counted toward the quorum.
if let Some(min_position) = min_position {
if let Ok(ok_res) = &res {
if (ok_res.term, ok_res.flush_lsn) < min_position {
// Use Error::Timeout to make this error retriable.
res = Err(mgmt_api::Error::Timeout(
format!(
"safekeeper {} returned position {:?} which is less than minimum required position {:?}",
client.node_id_label(),
(ok_res.term, ok_res.flush_lsn),
min_position
)
));
}
}
}
res
}
},
SK_SET_MEM_TIMELINE_RECONCILE_TIMEOUT,
)
.await?;
for res in results.iter().flatten() {
if res.current_conf.generation > config.generation {
// Antoher switch_membership raced us.
return Err(ApiError::Conflict(format!(
"received configuration with generation {} from safekeeper, but expected {}",
res.current_conf.generation, config.generation
)));
} else if res.current_conf.generation < config.generation {
// Note: should never happen.
// If we get a response, it should be at least the sent generation.
tracing::error!(
"received configuration with generation {} from safekeeper, but expected {}",
res.current_conf.generation,
config.generation
);
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"received configuration with generation {} from safekeeper, but expected {}",
res.current_conf.generation,
config.generation
)));
}
}
Ok(results)
}
/// Pull timeline to to_safekeepers from from_safekeepers with retries.
///
/// Returns Ok(()) only if all the pull_timeline requests were successful.
async fn tenant_timeline_pull_from_peers(
self: &Arc<Self>,
tenant_id: TenantId,
timeline_id: TimelineId,
to_safekeepers: &[Safekeeper],
from_safekeepers: &[Safekeeper],
) -> Result<(), ApiError> {
let http_hosts = from_safekeepers
.iter()
.map(|sk| sk.base_url())
.collect::<Vec<_>>();
tracing::info!(
"pulling timeline to {:?} from {:?}",
to_safekeepers
.iter()
.map(|sk| sk.get_id())
.collect::<Vec<_>>(),
from_safekeepers
.iter()
.map(|sk| sk.get_id())
.collect::<Vec<_>>()
);
// TODO(diko): need to pass mconf/generation with the request
// to properly handle tombstones. Ignore tombstones for now.
// Worst case: we leave a timeline on a safekeeper which is not in the current set.
let req = PullTimelineRequest {
tenant_id,
timeline_id,
http_hosts,
ignore_tombstone: Some(true),
};
const SK_PULL_TIMELINE_RECONCILE_TIMEOUT: Duration = Duration::from_secs(30);
let responses = self
.tenant_timeline_safekeeper_op(
to_safekeepers,
move |client| {
let req = req.clone();
async move { client.pull_timeline(&req).await }
},
SK_PULL_TIMELINE_RECONCILE_TIMEOUT,
)
.await?;
if let Some((idx, err)) = responses
.iter()
.enumerate()
.find_map(|(idx, res)| Some((idx, res.as_ref().err()?)))
{
let sk_id = to_safekeepers[idx].get_id();
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"pull_timeline to {sk_id} failed: {err}",
)));
}
Ok(())
}
/// Exclude a timeline from safekeepers in parallel with retries.
/// If an exclude request is unsuccessful, it will be added to
/// the reconciler, and after that the function will succeed.
async fn tenant_timeline_safekeeper_exclude(
self: &Arc<Self>,
tenant_id: TenantId,
timeline_id: TimelineId,
safekeepers: &[Safekeeper],
config: &membership::Configuration,
) -> Result<(), ApiError> {
let req = TimelineMembershipSwitchRequest {
mconf: config.clone(),
};
const SK_EXCLUDE_TIMELINE_TIMEOUT: Duration = Duration::from_secs(30);
let results = self
.tenant_timeline_safekeeper_op(
safekeepers,
move |client| {
let req = req.clone();
async move { client.exclude_timeline(tenant_id, timeline_id, &req).await }
},
SK_EXCLUDE_TIMELINE_TIMEOUT,
)
.await?;
let mut reconcile_requests = Vec::new();
for (idx, res) in results.iter().enumerate() {
if res.is_err() {
let sk_id = safekeepers[idx].skp.id;
let pending_op = TimelinePendingOpPersistence {
tenant_id: tenant_id.to_string(),
timeline_id: timeline_id.to_string(),
generation: config.generation.into_inner() as i32,
op_kind: SafekeeperTimelineOpKind::Exclude,
sk_id,
};
tracing::info!("writing pending exclude op for sk id {sk_id}");
self.persistence.insert_pending_op(pending_op).await?;
let req = ScheduleRequest {
safekeeper: Box::new(safekeepers[idx].clone()),
host_list: Vec::new(),
tenant_id,
timeline_id: Some(timeline_id),
generation: config.generation.into_inner(),
kind: SafekeeperTimelineOpKind::Exclude,
};
reconcile_requests.push(req);
}
}
if !reconcile_requests.is_empty() {
let locked = self.inner.read().unwrap();
for req in reconcile_requests {
locked.safekeeper_reconcilers.schedule_request(req);
}
}
Ok(())
}
/// Migrate timeline safekeeper set to a new set.
///
/// This function implements an algorithm from RFC-035.
/// <https://github.com/neondatabase/neon/blob/main/docs/rfcs/035-safekeeper-dynamic-membership-change.md>
pub(crate) async fn tenant_timeline_safekeeper_migrate(
self: &Arc<Self>,
tenant_id: TenantId,
timeline_id: TimelineId,
req: TimelineSafekeeperMigrateRequest,
) -> Result<(), ApiError> {
let all_safekeepers = self.inner.read().unwrap().safekeepers.clone();
let new_sk_set = req.new_sk_set;
for sk_id in new_sk_set.iter() {
if !all_safekeepers.contains_key(sk_id) {
return Err(ApiError::BadRequest(anyhow::anyhow!(
"safekeeper {sk_id} does not exist"
)));
}
}
// TODO(diko): per-tenant lock is too wide. Consider introducing per-timeline locks.
let _tenant_lock = trace_shared_lock(
&self.tenant_op_locks,
tenant_id,
TenantOperations::TimelineSafekeeperMigrate,
)
.await;
// 1. Fetch current timeline configuration from the configuration storage.
let timeline = self
.persistence
.get_timeline(tenant_id, timeline_id)
.await?;
let Some(timeline) = timeline else {
return Err(ApiError::NotFound(
anyhow::anyhow!(
"timeline {tenant_id}/{timeline_id} doesn't exist in timelines table"
)
.into(),
));
};
let cur_sk_set = timeline
.sk_set
.iter()
.map(|&id| NodeId(id as u64))
.collect::<Vec<_>>();
tracing::info!(
?cur_sk_set,
?new_sk_set,
"Migrating timeline to new safekeeper set",
);
let mut generation = SafekeeperGeneration::new(timeline.generation as u32);
if let Some(ref presistent_new_sk_set) = timeline.new_sk_set {
// 2. If it is already joint one and new_set is different from desired_set refuse to change.
if presistent_new_sk_set
.iter()
.map(|&id| NodeId(id as u64))
.ne(new_sk_set.iter().cloned())
{
tracing::info!(
?presistent_new_sk_set,
?new_sk_set,
"different new safekeeper set is already set in the database",
);
return Err(ApiError::Conflict(format!(
"the timeline is already migrating to a different safekeeper set: {presistent_new_sk_set:?}"
)));
}
// It it is the same new_sk_set, we can continue the migration (retry).
} else {
// 3. No active migration yet.
// Increment current generation and put desired_set to new_sk_set.
generation = generation.next();
self.persistence
.update_timeline_membership(
tenant_id,
timeline_id,
generation,
&cur_sk_set,
Some(&new_sk_set),
)
.await?;
}
let cur_safekeepers = self.get_safekeepers(&timeline.sk_set)?;
let cur_sk_member_set = Self::make_member_set(&cur_safekeepers)?;
let new_sk_set_i64 = new_sk_set.iter().map(|id| id.0 as i64).collect::<Vec<_>>();
let new_safekeepers = self.get_safekeepers(&new_sk_set_i64)?;
let new_sk_member_set = Self::make_member_set(&new_safekeepers)?;
let joint_config = membership::Configuration {
generation,
members: cur_sk_member_set,
new_members: Some(new_sk_member_set.clone()),
};
// 4. Call PUT configuration on safekeepers from the current set,
// delivering them joint_conf.
// Notify cplane/compute about the membership change BEFORE changing the membership on safekeepers.
// This way the compute will know about new safekeepers from joint_config before we require to
// collect a quorum from them.
self.cplane_notify_safekeepers(tenant_id, timeline_id, &joint_config)
.await?;
let results = self
.tenant_timeline_set_membership_quorum(
tenant_id,
timeline_id,
&cur_safekeepers,
&joint_config,
None, // no min position
)
.await?;
let mut sync_position = (INITIAL_TERM, Lsn::INVALID);
for res in results.into_iter().flatten() {
let sk_position = (res.term, res.flush_lsn);
if sync_position < sk_position {
sync_position = sk_position;
}
}
tracing::info!(
%generation,
?sync_position,
"safekeepers set membership updated",
);
// 5. Initialize timeline on safekeeper(s) from new_sk_set where it doesn't exist yet
// by doing pull_timeline from the majority of the current set.
// Filter out safekeepers which are already in the current set.
let from_ids: HashSet<NodeId> = cur_safekeepers.iter().map(|sk| sk.get_id()).collect();
let pull_to_safekeepers = new_safekeepers
.iter()
.filter(|sk| !from_ids.contains(&sk.get_id()))
.cloned()
.collect::<Vec<_>>();
self.tenant_timeline_pull_from_peers(
tenant_id,
timeline_id,
&pull_to_safekeepers,
&cur_safekeepers,
)
.await?;
// 6. Call POST bump_term(sync_term) on safekeepers from the new set. Success on majority is enough.
// TODO(diko): do we need to bump timeline term?
// 7. Repeatedly call PUT configuration on safekeepers from the new set,
// delivering them joint_conf and collecting their positions.
tracing::info!(?sync_position, "waiting for safekeepers to sync position");
self.tenant_timeline_set_membership_quorum(
tenant_id,
timeline_id,
&new_safekeepers,
&joint_config,
Some(sync_position),
)
.await?;
// 8. Create new_conf: Configuration incrementing joint_conf generation and
// having new safekeeper set as sk_set and None new_sk_set.
let generation = generation.next();
let new_conf = membership::Configuration {
generation,
members: new_sk_member_set,
new_members: None,
};
self.persistence
.update_timeline_membership(tenant_id, timeline_id, generation, &new_sk_set, None)
.await?;
// TODO(diko): at this point we have already updated the timeline in the database,
// but we still need to notify safekeepers and cplane about the new configuration,
// and put delition of the timeline from the old safekeepers into the reconciler.
// Ideally it should be done atomically, but now it's not.
// Worst case: the timeline is not deleted from old safekeepers,
// the compute may require both quorums till the migration is retried and completed.
self.tenant_timeline_set_membership_quorum(
tenant_id,
timeline_id,
&new_safekeepers,
&new_conf,
None, // no min position
)
.await?;
let new_ids: HashSet<NodeId> = new_safekeepers.iter().map(|sk| sk.get_id()).collect();
let exclude_safekeepers = cur_safekeepers
.into_iter()
.filter(|sk| !new_ids.contains(&sk.get_id()))
.collect::<Vec<_>>();
self.tenant_timeline_safekeeper_exclude(
tenant_id,
timeline_id,
&exclude_safekeepers,
&new_conf,
)
.await?;
// Notify cplane/compute about the membership change AFTER changing the membership on safekeepers.
// This way the compute will stop talking to excluded safekeepers only after we stop requiring to
// collect a quorum from them.
self.cplane_notify_safekeepers(tenant_id, timeline_id, &new_conf)
.await?;
Ok(())
}
/// Notify cplane about safekeeper membership change.
/// The cplane will receive a joint set of safekeepers as a safekeeper list.
async fn cplane_notify_safekeepers(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
mconf: &membership::Configuration,
) -> Result<(), ApiError> {
let mut safekeepers = Vec::new();
let mut ids: HashSet<_> = HashSet::new();
for member in mconf
.members
.m
.iter()
.chain(mconf.new_members.iter().flat_map(|m| m.m.iter()))
{
if ids.insert(member.id) {
safekeepers.push(compute_hook::SafekeeperInfo {
id: member.id,
hostname: Some(member.host.clone()),
});
}
}
self.compute_hook
.notify_safekeepers(
compute_hook::SafekeepersUpdate {
tenant_id,
timeline_id,
generation: mconf.generation,
safekeepers,
},
&self.cancel,
)
.await
.map_err(|err| {
ApiError::InternalServerError(anyhow::anyhow!(
"failed to notify cplane about safekeeper membership change: {err}"
))
})
}
}

View File

@@ -453,7 +453,7 @@ class NeonEnvBuilder:
pageserver_get_vectored_concurrent_io: str | None = None,
pageserver_tracing_config: PageserverTracingConfig | None = None,
pageserver_import_config: PageserverImportConfig | None = None,
storcon_kick_secondary_downloads: bool | None = None,
storcon_kick_secondary_downloads: bool | None = True,
):
self.repo_dir = repo_dir
self.rust_log_override = rust_log_override
@@ -1215,6 +1215,13 @@ class NeonEnv:
storage_controller_config = storage_controller_config or {}
storage_controller_config["use_https_safekeeper_api"] = True
# TODO(diko): uncomment when timeline_safekeeper_count option is in the release branch,
# so the compat tests will not fail bacause of it presence.
# if config.num_safekeepers < 3:
# storage_controller_config = storage_controller_config or {}
# if "timeline_safekeeper_count" not in storage_controller_config:
# storage_controller_config["timeline_safekeeper_count"] = config.num_safekeepers
if storage_controller_config is not None:
cfg["storage_controller"] = storage_controller_config
@@ -2226,6 +2233,21 @@ class NeonStorageController(MetricsGetter, LogUtils):
response.raise_for_status()
log.info(f"timeline_create success: {response.json()}")
def migrate_safekeepers(
self,
tenant_id: TenantId,
timeline_id: TimelineId,
new_sk_set: list[int],
):
response = self.request(
"POST",
f"{self.api}/v1/tenant/{tenant_id}/timeline/{timeline_id}/safekeeper_migrate",
json={"new_sk_set": new_sk_set},
headers=self.headers(TokenScope.PAGE_SERVER_API),
)
response.raise_for_status()
log.info(f"migrate_safekeepers success: {response.json()}")
def locate(self, tenant_id: TenantId) -> list[dict[str, Any]]:
"""
:return: list of {"shard_id": "", "node_id": int, "listen_pg_addr": str, "listen_pg_port": int, "listen_http_addr": str, "listen_http_port": int}

Some files were not shown because too many files have changed in this diff Show More