Merge pull request #2170 from neondatabase/main (Release 2022-07-28)

Release 2022-07-28
This commit is contained in:
Sergey Melnikov
2022-07-28 14:16:52 +03:00
committed by GitHub
76 changed files with 4703 additions and 3517 deletions

View File

@@ -1,369 +0,0 @@
version: 2.1
# Docker executors referenced by the jobs below. Both use the same rust image;
# only neon-xlarge-executor sets a resource_class.
executors:
neon-xlarge-executor:
resource_class: xlarge
docker:
# NB: when changed, do not forget to update rust image tag in all Dockerfiles
- image: neondatabase/rust:1.58
# Same image as above, but without an explicit resource_class.
neon-executor:
docker:
- image: neondatabase/rust:1.58
jobs:
# A job to build postgres
#
# The result (the tmp_install directory) is stored in a cache keyed by build type,
# the vendor/postgres revision and the Makefile contents.
build-postgres:
executor: neon-xlarge-executor
parameters:
# Build flavour; also exported to the steps as BUILD_TYPE.
build_type:
type: enum
enum: ["debug", "release"]
environment:
BUILD_TYPE: << parameters.build_type >>
steps:
# Checkout the git repo (circleci doesn't have a flag to enable submodules here)
- checkout
# Grab the postgres git revision to build a cache key.
# Append makefile as it could change the way postgres is built.
# Note this works even though the submodule hasn't been checked out yet.
- run:
name: Get postgres cache key
command: |
git rev-parse HEAD:vendor/postgres > /tmp/cache-key-postgres
cat Makefile >> /tmp/cache-key-postgres
- restore_cache:
name: Restore postgres cache
keys:
# Restore ONLY if the rev key matches exactly
- v05-postgres-cache-<< parameters.build_type >>-{{ checksum "/tmp/cache-key-postgres" }}
# Build postgres if the restore_cache didn't find a build.
# `make` can't figure out whether the cache is valid, since
# it only compares file timestamps.
- run:
name: build postgres
command: |
if [ ! -e tmp_install/bin/postgres ]; then
# "depth 1" saves some time by not cloning the whole repo
git submodule update --init --depth 1
# bail out on any warnings
COPT='-Werror' mold -run make postgres -j$(nproc)
fi
# Saved under the same key that restore_cache looks up above.
- save_cache:
name: Save postgres cache
key: v05-postgres-cache-<< parameters.build_type >>-{{ checksum "/tmp/cache-key-postgres" }}
paths:
- tmp_install
# A job to build Neon rust code
#
# Builds the workspace with cargo, runs the rust unit tests, then copies the
# built binaries and the postgres install into /tmp/zenith and persists that
# directory to the workspace for the test jobs.
build-neon:
executor: neon-xlarge-executor
parameters:
# Build flavour; also exported to the steps as BUILD_TYPE.
build_type:
type: enum
enum: ["debug", "release"]
environment:
BUILD_TYPE: << parameters.build_type >>
steps:
# Checkout the git repo (without submodules)
- checkout
# Grab the postgres git revision to build a cache key.
# Append makefile as it could change the way postgres is built.
# Note this works even though the submodule hasn't been checked out yet.
- run:
name: Get postgres cache key
command: |
git rev-parse HEAD:vendor/postgres > /tmp/cache-key-postgres
cat Makefile >> /tmp/cache-key-postgres
# This job has no step that builds postgres: tmp_install must come from this cache.
- restore_cache:
name: Restore postgres cache
keys:
# Restore ONLY if the rev key matches exactly
- v05-postgres-cache-<< parameters.build_type >>-{{ checksum "/tmp/cache-key-postgres" }}
- restore_cache:
name: Restore rust cache
keys:
# Require an exact match. While an out of date cache might speed up the build,
# there's no way to clean out old packages, so the cache grows every time something
# changes.
- v05-rust-cache-deps-<< parameters.build_type >>-{{ checksum "Cargo.lock" }}
# Build the rust code, including test binaries
# NOTE(review): RUSTC_WRAPPER is exported as an empty string while the cachepot
# bucket/credentials are still exported and `cachepot -s` is still called;
# confirm that running without the wrapper is intended.
- run:
name: Rust build << parameters.build_type >>
command: |
if [[ $BUILD_TYPE == "debug" ]]; then
CARGO_FLAGS=
elif [[ $BUILD_TYPE == "release" ]]; then
CARGO_FLAGS="--release --features profiling"
fi
export CARGO_INCREMENTAL=0
export CACHEPOT_BUCKET=zenith-rust-cachepot
export RUSTC_WRAPPER=""
export AWS_ACCESS_KEY_ID="${CACHEPOT_AWS_ACCESS_KEY_ID}"
export AWS_SECRET_ACCESS_KEY="${CACHEPOT_AWS_SECRET_ACCESS_KEY}"
mold -run cargo build $CARGO_FLAGS --features failpoints --bins --tests
cachepot -s
- save_cache:
name: Save rust cache
key: v05-rust-cache-deps-<< parameters.build_type >>-{{ checksum "Cargo.lock" }}
paths:
- ~/.cargo/registry
- ~/.cargo/git
- target
# Run rust unit tests
# NOTE(review): unlike the build step above, no --features flags are passed here
# (release uses only --release); confirm the differing feature set is intended.
- run:
name: cargo test
command: |
if [[ $BUILD_TYPE == "debug" ]]; then
CARGO_FLAGS=
elif [[ $BUILD_TYPE == "release" ]]; then
CARGO_FLAGS=--release
fi
cargo test $CARGO_FLAGS
# Install the rust binaries, for use by test jobs
# The binary names are taken from `cargo metadata` (targets whose kind includes
# "bin") and each one is copied from target/$BUILD_TYPE into /tmp/zenith/bin.
# test_bin and etc are created but not populated by the steps of this job.
- run:
name: Install rust binaries
command: |
binaries=$(
cargo metadata --format-version=1 --no-deps |
jq -r '.packages[].targets[] | select(.kind | index("bin")) | .name'
)
mkdir -p /tmp/zenith/bin
mkdir -p /tmp/zenith/test_bin
mkdir -p /tmp/zenith/etc
# Install target binaries
for bin in $binaries; do
SRC=target/$BUILD_TYPE/$bin
DST=/tmp/zenith/bin/$bin
cp $SRC $DST
done
# Install the postgres binaries, for use by test jobs
- run:
name: Install postgres binaries
command: |
cp -a tmp_install /tmp/zenith/pg_install
# Save rust binaries for other jobs in the workflow
- persist_to_workspace:
root: /tmp/zenith
paths:
- "*"
# A job to check the python code: installs the poetry-managed dependencies, then
# runs the yapf format check and the mypy type check.
check-codestyle-python:
executor: neon-executor
steps:
- checkout
# The python dependency cache is keyed by the poetry.lock checksum.
- restore_cache:
keys:
- v2-python-deps-{{ checksum "poetry.lock" }}
- run:
name: Install deps
command: ./scripts/pysync
- save_cache:
key: v2-python-deps-{{ checksum "poetry.lock" }}
paths:
- /home/circleci/.cache/pypoetry/virtualenvs
# The remaining steps are marked `when: always`, so each of them runs
# regardless of whether an earlier step failed.
- run:
name: Print versions
when: always
command: |
poetry run python --version
poetry show
- run:
name: Run yapf to ensure code format
when: always
command: poetry run yapf --recursive --diff .
- run:
name: Run mypy to check types
when: always
command: poetry run mypy .
# A job to run a selection of the python tests against the binaries that
# build-neon persisted to the workspace (attached at /tmp/zenith).
run-pytest:
  executor: neon-executor
  parameters:
    # Select the type of Rust build. Must be "release" or "debug".
    build_type:
      type: string
      default: "debug"
    # pytest args to specify the tests to run.
    #
    # This can be a test file name, e.g. 'test_pgbench.py', or a subdirectory,
    # or '-k foobar' to run tests containing string 'foobar'. See pytest man page
    # section SPECIFYING TESTS / SELECTING TESTS for details.
    #
    # This parameter is required, to prevent the mistake of running all tests in one job.
    # The empty default is rejected at the start of the "Run pytest" step.
    test_selection:
      type: string
      default: ""
    # Arbitrary parameters to pytest. For example "-s" to prevent capturing stdout/stderr
    extra_params:
      type: string
      default: ""
    # When true, the git submodules are checked out before the tests run.
    needs_postgres_source:
      type: boolean
      default: false
    # When true, "-n4" is prepended to the pytest parameters.
    run_in_parallel:
      type: boolean
      default: true
    # When true, and only on the main branch, a perf report is written and pushed.
    save_perf_report:
      type: boolean
      default: false
  environment:
    BUILD_TYPE: << parameters.build_type >>
  steps:
    - attach_workspace:
        at: /tmp/zenith
    - checkout
    - when:
        condition: << parameters.needs_postgres_source >>
        steps:
          - run: git submodule update --init --depth 1
    - restore_cache:
        keys:
          - v2-python-deps-{{ checksum "poetry.lock" }}
    - run:
        name: Install deps
        command: ./scripts/pysync
    - save_cache:
        key: v2-python-deps-{{ checksum "poetry.lock" }}
        paths:
          - /home/circleci/.cache/pypoetry/virtualenvs
    - run:
        name: Run pytest
        # pytest doesn't output test logs in real time, so CI job may fail with
        # `Too long with no output` error, if a test is running for a long time.
        # In that case, tests should have internal timeouts that are less than
        # no_output_timeout, specified here.
        no_output_timeout: 10m
        environment:
          - NEON_BIN: /tmp/zenith/bin
          - POSTGRES_DISTRIB_DIR: /tmp/zenith/pg_install
          - TEST_OUTPUT: /tmp/test_output
          # this variable will be embedded in perf test report
          # and is needed to distinguish different environments
          - PLATFORM: zenith-local-ci
        command: |
          # Reject an empty test selection before doing anything else. The raw
          # parameter has to be checked: once it is prefixed with "test_runner/"
          # the resulting string is never empty, so checking that would never fail
          # and an empty selection would silently run every test in test_runner/.
          if [ -z "<< parameters.test_selection >>" ]; then
            echo "test_selection must be set"
            exit 1
          fi
          PERF_REPORT_DIR="$(realpath test_runner/perf-report-local)"
          rm -rf "$PERF_REPORT_DIR"
          TEST_SELECTION="test_runner/<< parameters.test_selection >>"
          EXTRA_PARAMS="<< parameters.extra_params >>"
          if << parameters.run_in_parallel >>; then
            EXTRA_PARAMS="-n4 $EXTRA_PARAMS"
          fi
          if << parameters.save_perf_report >>; then
            if [[ $CIRCLE_BRANCH == "main" ]]; then
              mkdir -p "$PERF_REPORT_DIR"
              EXTRA_PARAMS="--out-dir $PERF_REPORT_DIR $EXTRA_PARAMS"
            fi
          fi
          export GITHUB_SHA=$CIRCLE_SHA1
          # Run the tests.
          #
          # The junit.xml file allows CircleCI to display more fine-grained test information
          # in its "Tests" tab in the results page.
          # --verbose prints name of each test (helpful when there are
          # multiple tests in one file)
          # -rA prints summary in the end
          # -n4 uses four processes to run tests via pytest-xdist
          # -s (which stops pytest from capturing output) is deliberately not passed,
          # because tests are running in parallel and logs are mixed between different tests
          #
          # $TEST_SELECTION and $EXTRA_PARAMS are left unquoted on purpose so that
          # they are split into separate pytest arguments.
          ./scripts/pytest \
            --junitxml=$TEST_OUTPUT/junit.xml \
            --tb=short \
            --verbose \
            -m "not remote_cluster" \
            -rA $TEST_SELECTION $EXTRA_PARAMS
          if << parameters.save_perf_report >>; then
            if [[ $CIRCLE_BRANCH == "main" ]]; then
              export REPORT_FROM="$PERF_REPORT_DIR"
              export REPORT_TO=local
              scripts/generate_and_push_perf_report.sh
            fi
          fi
    - run:
        # CircleCI artifacts are preserved one file at a time, so skipping
        # this step isn't a good idea. If you want to extract the
        # pageserver state, perhaps a tarball would be a better idea.
        name: Delete all data but logs
        when: always
        command: |
          du -sh /tmp/test_output/*
          find /tmp/test_output -type f ! -name "*.log" ! -name "regression.diffs" ! -name "junit.xml" ! -name "*.filediff" ! -name "*.stdout" ! -name "*.stderr" ! -name "flamegraph.svg" ! -name "*.metrics" -delete
          du -sh /tmp/test_output/*
    - store_artifacts:
        path: /tmp/test_output
    # The store_test_results step tells CircleCI where to find the junit.xml file.
    - store_test_results:
        path: /tmp/test_output
    # Save data (if any)
    - persist_to_workspace:
        root: /tmp/zenith
        paths:
          - "*"
# Workflow wiring: which jobs run, with which parameters, and in what order.
workflows:
build_and_test:
jobs:
- check-codestyle-python
# The matrix entries below run once per build_type. The build type is embedded
# in each job name so that `requires` entries can refer to the matching upstream job.
- build-postgres:
name: build-postgres-<< matrix.build_type >>
matrix:
parameters:
build_type: ["debug", "release"]
- build-neon:
name: build-neon-<< matrix.build_type >>
matrix:
parameters:
build_type: ["debug", "release"]
requires:
- build-postgres-<< matrix.build_type >>
# batch_pg_regress selection; this is the only run-pytest entry that sets
# needs_postgres_source.
- run-pytest:
name: pg_regress-tests-<< matrix.build_type >>
matrix:
parameters:
build_type: ["debug", "release"]
test_selection: batch_pg_regress
needs_postgres_source: true
requires:
- build-neon-<< matrix.build_type >>
# batch_others selection, for both build types.
- run-pytest:
name: other-tests-<< matrix.build_type >>
matrix:
parameters:
build_type: ["debug", "release"]
test_selection: batch_others
requires:
- build-neon-<< matrix.build_type >>
# Performance tests: release build only, parallel execution disabled,
# perf report saving enabled.
- run-pytest:
name: benchmarks
context: PERF_TEST_RESULT_CONNSTR
build_type: release
test_selection: performance
run_in_parallel: false
save_perf_report: true
requires:
- build-neon-release

View File

@@ -37,26 +37,13 @@ runs:
name: neon-${{ runner.os }}-${{ inputs.build_type }}-${{ inputs.rust_toolchain }}-artifact
path: ./neon-artifact/
- name: Get Postgres artifact for restoration
uses: actions/download-artifact@v3
with:
name: postgres-${{ runner.os }}-${{ inputs.build_type }}-artifact
path: ./pg-artifact/
- name: Extract Neon artifact
shell: bash -ex {0}
run: |
mkdir -p /tmp/neon/
tar -xf ./neon-artifact/neon.tgz -C /tmp/neon/
tar -xf ./neon-artifact/neon.tar.zst -C /tmp/neon/
rm -rf ./neon-artifact/
- name: Extract Postgres artifact
shell: bash -ex {0}
run: |
mkdir -p /tmp/neon/tmp_install
tar -xf ./pg-artifact/pg.tgz -C /tmp/neon/tmp_install
rm -rf ./pg-artifact/
- name: Checkout
if: inputs.needs_postgres_source == 'true'
uses: actions/checkout@v3
@@ -78,7 +65,7 @@ runs:
- name: Run pytest
env:
NEON_BIN: /tmp/neon/bin
POSTGRES_DISTRIB_DIR: /tmp/neon/tmp_install
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
TEST_OUTPUT: /tmp/test_output
# this variable will be embedded in perf test report
# and is needed to distinguish different environments
@@ -112,7 +99,7 @@ runs:
# Run the tests.
#
# The junit.xml file allows CircleCI to display more fine-grained test information
# The junit.xml file allows CI tools to display more fine-grained test information
# in its "Tests" tab in the results page.
# --verbose prints name of each test (helpful when there are
# multiple tests in one file)

View File

@@ -17,4 +17,4 @@ env_name = prod-1
console_mgmt_base_url = http://console-release.local
bucket_name = zenith-storage-oregon
bucket_region = us-west-2
etcd_endpoints = etcd-release.local:2379
etcd_endpoints = zenith-1-etcd.local:2379

View File

@@ -12,10 +12,9 @@ cat <<EOF | tee /tmp/payload
"version": 1,
"host": "${HOST}",
"port": 6500,
"http_port": 7676,
"region_id": {{ console_region_id }},
"instance_id": "${INSTANCE_ID}",
"http_host": "${HOST}",
"http_port": 7676
"instance_id": "${INSTANCE_ID}"
}
EOF

View File

@@ -17,4 +17,4 @@ env_name = us-stage
console_mgmt_base_url = http://console-staging.local
bucket_name = zenith-staging-storage-us-east-1
bucket_region = us-east-1
etcd_endpoints = etcd-staging.local:2379
etcd_endpoints = zenith-us-stage-etcd.local:2379

View File

@@ -11,7 +11,7 @@ on:
# │ │ ┌───────────── day of the month (1 - 31)
# │ │ │ ┌───────────── month (1 - 12 or JAN-DEC)
# │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT)
- cron: '36 7 * * *' # run once a day, timezone is utc
- cron: '36 4 * * *' # run once a day, timezone is utc
workflow_dispatch: # adds ability to run this manually

View File

@@ -21,7 +21,7 @@ env:
COPT: '-Werror'
jobs:
build-postgres:
build-neon:
runs-on: [ self-hosted, Linux, k8s-runner ]
strategy:
fail-fast: false
@@ -31,6 +31,7 @@ jobs:
env:
BUILD_TYPE: ${{ matrix.build_type }}
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -42,58 +43,28 @@ jobs:
id: pg_ver
run: echo ::set-output name=pg_rev::$(git rev-parse HEAD:vendor/postgres)
- name: Cache postgres build
id: cache_pg
uses: actions/cache@v3
with:
path: tmp_install/
key: v1-${{ runner.os }}-${{ matrix.build_type }}-pg-${{ steps.pg_ver.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Build postgres
if: steps.cache_pg.outputs.cache-hit != 'true'
run: mold -run make postgres -j$(nproc)
# actions/cache@v3 does not allow concurrently using the same cache across job steps, so use a separate cache
- name: Prepare postgres artifact
run: tar -C tmp_install/ -czf ./pg.tgz .
- name: Upload postgres artifact
uses: actions/upload-artifact@v3
with:
retention-days: 7
if-no-files-found: error
name: postgres-${{ runner.os }}-${{ matrix.build_type }}-artifact
path: ./pg.tgz
build-neon:
runs-on: [ self-hosted, Linux, k8s-runner ]
needs: [ build-postgres ]
strategy:
fail-fast: false
matrix:
build_type: [ debug, release ]
rust_toolchain: [ 1.58 ]
env:
BUILD_TYPE: ${{ matrix.build_type }}
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 1
- name: Get postgres artifact for restoration
uses: actions/download-artifact@v3
with:
name: postgres-${{ runner.os }}-${{ matrix.build_type }}-artifact
path: ./postgres-artifact/
- name: Extract postgres artifact
# Set some environment variables used by all the steps.
#
# CARGO_FLAGS is extra options to pass to "cargo build", "cargo test" etc.
# It also includes --features, if any
#
# CARGO_FEATURES is passed to "cargo metadata". It is separate from CARGO_FLAGS,
# because "cargo metadata" doesn't accept --release or --debug options
#
- name: Set env variables
run: |
mkdir ./tmp_install/
tar -xf ./postgres-artifact/pg.tgz -C ./tmp_install/
rm -rf ./postgres-artifact/
if [[ $BUILD_TYPE == "debug" ]]; then
cov_prefix="scripts/coverage --profraw-prefix=$GITHUB_JOB --dir=/tmp/coverage run"
CARGO_FEATURES=""
CARGO_FLAGS=""
elif [[ $BUILD_TYPE == "release" ]]; then
cov_prefix=""
CARGO_FEATURES="--features profiling"
CARGO_FLAGS="--release $CARGO_FEATURES"
fi
echo "cov_prefix=${cov_prefix}" >> $GITHUB_ENV
echo "CARGO_FEATURES=${CARGO_FEATURES}" >> $GITHUB_ENV
echo "CARGO_FLAGS=${CARGO_FLAGS}" >> $GITHUB_ENV
# Don't include the ~/.cargo/registry/src directory. It contains just
# uncompressed versions of the crates in ~/.cargo/registry/cache
@@ -110,59 +81,36 @@ jobs:
target/
# Fall back to older versions of the key, if no cache for current Cargo.lock was found
key: |
v2-${{ runner.os }}-${{ matrix.build_type }}-cargo-${{ matrix.rust_toolchain }}-${{ hashFiles('Cargo.lock') }}
v2-${{ runner.os }}-${{ matrix.build_type }}-cargo-${{ matrix.rust_toolchain }}-
v3-${{ runner.os }}-${{ matrix.build_type }}-cargo-${{ matrix.rust_toolchain }}-${{ hashFiles('Cargo.lock') }}
v3-${{ runner.os }}-${{ matrix.build_type }}-cargo-${{ matrix.rust_toolchain }}-
- name: Cache postgres build
id: cache_pg
uses: actions/cache@v3
with:
path: tmp_install/
key: v1-${{ runner.os }}-${{ matrix.build_type }}-pg-${{ steps.pg_ver.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Build postgres
if: steps.cache_pg.outputs.cache-hit != 'true'
run: mold -run make postgres -j$(nproc)
- name: Run cargo build
run: |
if [[ $BUILD_TYPE == "debug" ]]; then
cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage run)
CARGO_FLAGS=
elif [[ $BUILD_TYPE == "release" ]]; then
cov_prefix=()
CARGO_FLAGS="--release --features profiling"
fi
"${cov_prefix[@]}" mold -run cargo build $CARGO_FLAGS --features failpoints --bins --tests
${cov_prefix} mold -run cargo build $CARGO_FLAGS --features failpoints --bins --tests
- name: Run cargo test
run: |
if [[ $BUILD_TYPE == "debug" ]]; then
cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage run)
CARGO_FLAGS=
elif [[ $BUILD_TYPE == "release" ]]; then
cov_prefix=()
CARGO_FLAGS=--release
fi
"${cov_prefix[@]}" cargo test $CARGO_FLAGS
${cov_prefix} cargo test $CARGO_FLAGS
- name: Install rust binaries
run: |
if [[ $BUILD_TYPE == "debug" ]]; then
cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage run)
elif [[ $BUILD_TYPE == "release" ]]; then
cov_prefix=()
fi
# Install target binaries
mkdir -p /tmp/neon/bin/
binaries=$(
"${cov_prefix[@]}" cargo metadata --format-version=1 --no-deps |
${cov_prefix} cargo metadata $CARGO_FEATURES --format-version=1 --no-deps |
jq -r '.packages[].targets[] | select(.kind | index("bin")) | .name'
)
test_exe_paths=$(
"${cov_prefix[@]}" cargo test --message-format=json --no-run |
jq -r '.executable | select(. != null)'
)
mkdir -p /tmp/neon/bin/
mkdir -p /tmp/neon/test_bin/
mkdir -p /tmp/neon/etc/
# Keep bloated coverage data files away from the rest of the artifact
mkdir -p /tmp/coverage/
# Install target binaries
for bin in $binaries; do
SRC=target/$BUILD_TYPE/$bin
DST=/tmp/neon/bin/$bin
@@ -171,9 +119,14 @@ jobs:
# Install test executables and write list of all binaries (for code coverage)
if [[ $BUILD_TYPE == "debug" ]]; then
for bin in $binaries; do
echo "/tmp/neon/bin/$bin" >> /tmp/coverage/binaries.list
done
# Keep bloated coverage data files away from the rest of the artifact
mkdir -p /tmp/coverage/
mkdir -p /tmp/neon/test_bin/
test_exe_paths=$(
${cov_prefix} cargo test $CARGO_FLAGS --message-format=json --no-run |
jq -r '.executable | select(. != null)'
)
for bin in $test_exe_paths; do
SRC=$bin
DST=/tmp/neon/test_bin/$(basename $bin)
@@ -183,10 +136,17 @@ jobs:
strip "$SRC" -o "$DST"
echo "$DST" >> /tmp/coverage/binaries.list
done
for bin in $binaries; do
echo "/tmp/neon/bin/$bin" >> /tmp/coverage/binaries.list
done
fi
- name: Install postgres binaries
run: cp -a tmp_install /tmp/neon/pg_install
- name: Prepare neon artifact
run: tar -C /tmp/neon/ -czf ./neon.tgz .
run: ZSTD_NBTHREADS=0 tar -C /tmp/neon/ -cf ./neon.tar.zst --zstd .
- name: Upload neon binaries
uses: actions/upload-artifact@v3
@@ -194,7 +154,7 @@ jobs:
retention-days: 7
if-no-files-found: error
name: neon-${{ runner.os }}-${{ matrix.build_type }}-${{ matrix.rust_toolchain }}-artifact
path: ./neon.tgz
path: ./neon.tar.zst
# XXX: keep this after the binaries.list is formed, so the coverage can properly work later
- name: Merge and upload coverage data
@@ -308,7 +268,7 @@ jobs:
!~/.cargo/registry/src
~/.cargo/git/
target/
key: v2-${{ runner.os }}-${{ matrix.build_type }}-cargo-${{ matrix.rust_toolchain }}-${{ hashFiles('Cargo.lock') }}
key: v3-${{ runner.os }}-${{ matrix.build_type }}-cargo-${{ matrix.rust_toolchain }}-${{ hashFiles('Cargo.lock') }}
- name: Get Neon artifact for restoration
uses: actions/download-artifact@v3
@@ -319,7 +279,7 @@ jobs:
- name: Extract Neon artifact
run: |
mkdir -p /tmp/neon/
tar -xf ./neon-artifact/neon.tgz -C /tmp/neon/
tar -xf ./neon-artifact/neon.tar.zst -C /tmp/neon/
rm -rf ./neon-artifact/
- name: Restore coverage data
@@ -557,7 +517,7 @@ jobs:
if [[ "$GITHUB_REF_NAME" == "main" ]]; then
STAGING='{"env_name": "staging", "proxy_job": "neon-proxy", "proxy_config": "staging.proxy", "kubeconfig_secret": "STAGING_KUBECONFIG_DATA"}'
NEON_STRESS='{"env_name": "neon-stress", "proxy_job": "neon-stress-proxy", "proxy_config": "neon-stress.proxy", "kubeconfig_secret": "NEON_STRESS_KUBECONFIG_DATA"}'
echo "::set-output name=include::[$STAGING, $NEON_STRESS]"
echo "::set-output name=include::[$STAGING]"
elif [[ "$GITHUB_REF_NAME" == "release" ]]; then
PRODUCTION='{"env_name": "production", "proxy_job": "neon-proxy", "proxy_config": "production.proxy", "kubeconfig_secret": "PRODUCTION_KUBECONFIG_DATA"}'
echo "::set-output name=include::[$PRODUCTION]"

View File

@@ -101,7 +101,7 @@ jobs:
!~/.cargo/registry/src
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('./Cargo.lock') }}-rust-${{ matrix.rust_toolchain }}
key: v1-${{ runner.os }}-cargo-${{ hashFiles('./Cargo.lock') }}-rust-${{ matrix.rust_toolchain }}
- name: Run cargo clippy
run: ./run_clippy.sh

8
Cargo.lock generated
View File

@@ -467,7 +467,6 @@ dependencies = [
"clap 3.2.12",
"env_logger",
"hyper",
"libc",
"log",
"postgres",
"regex",
@@ -517,7 +516,6 @@ dependencies = [
"tar",
"thiserror",
"toml",
"url",
"utils",
"workspace_hack",
]
@@ -1604,7 +1602,6 @@ version = "0.1.0"
dependencies = [
"lazy_static",
"libc",
"once_cell",
"prometheus",
"workspace_hack",
]
@@ -1677,7 +1674,6 @@ dependencies = [
"git-version",
"pageserver",
"postgres",
"postgres_ffi",
"safekeeper",
"serde_json",
"utils",
@@ -1905,7 +1901,6 @@ dependencies = [
"thiserror",
"tokio",
"tokio-postgres",
"tokio-stream",
"toml_edit",
"tracing",
"url",
@@ -2764,7 +2759,6 @@ dependencies = [
"daemonize",
"etcd_broker",
"fs2",
"futures",
"git-version",
"hex",
"humantime",
@@ -2784,12 +2778,10 @@ dependencies = [
"tempfile",
"tokio",
"tokio-postgres",
"tokio-util",
"toml_edit",
"tracing",
"url",
"utils",
"walkdir",
"workspace_hack",
]

View File

@@ -17,6 +17,10 @@ RUN set -e \
FROM neondatabase/rust:1.58 AS build
ARG GIT_VERSION=local
# Enable https://github.com/paritytech/cachepot to cache Rust crates' compilation results in Docker builds.
# Set up cachepot to use an AWS S3 bucket for cache results, to reuse it between `docker build` invocations.
# cachepot falls back to local filesystem if S3 is misconfigured, not failing the build.
ARG RUSTC_WRAPPER=cachepot
ARG CACHEPOT_BUCKET=zenith-rust-cachepot
ARG AWS_ACCESS_KEY_ID
ARG AWS_SECRET_ACCESS_KEY

View File

@@ -1,7 +1,11 @@
# First transient image to build compute_tools binaries
# NB: keep in sync with rust image version in .circle/config.yml
# NB: keep in sync with rust image version in .github/workflows/build_and_test.yml
FROM neondatabase/rust:1.58 AS rust-build
# Enable https://github.com/paritytech/cachepot to cache Rust crates' compilation results in Docker builds.
# Set up cachepot to use an AWS S3 bucket for cache results, to reuse it between `docker build` invocations.
# cachepot falls back to local filesystem if S3 is misconfigured, not failing the build.
ARG RUSTC_WRAPPER=cachepot
ARG CACHEPOT_BUCKET=zenith-rust-cachepot
ARG AWS_ACCESS_KEY_ID
ARG AWS_SECRET_ACCESS_KEY

View File

@@ -4,7 +4,6 @@ version = "0.1.0"
edition = "2021"
[dependencies]
libc = "0.2"
anyhow = "1.0"
chrono = "0.4"
clap = "3.0"

View File

@@ -14,7 +14,6 @@ regex = "1"
anyhow = "1.0"
thiserror = "1"
nix = "0.23"
url = "2.2.2"
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
pageserver = { path = "../pageserver" }

View File

@@ -304,10 +304,9 @@ impl SafekeeperNode {
Ok(self
.http_request(
Method::POST,
format!("{}/{}", self.http_base_url, "timeline"),
format!("{}/tenant/{}/timeline", self.http_base_url, tenant_id),
)
.json(&TimelineCreateRequest {
tenant_id,
timeline_id,
peer_ids,
})

View File

@@ -7,5 +7,4 @@ edition = "2021"
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
libc = "0.2"
lazy_static = "1.4"
once_cell = "1.8.0"
workspace_hack = { version = "0.1", path = "../../workspace_hack" }

View File

@@ -3,6 +3,9 @@
//! Otherwise, we might not see all metrics registered via
//! a default registry.
use lazy_static::lazy_static;
use prometheus::core::{AtomicU64, GenericGauge, GenericGaugeVec};
pub use prometheus::opts;
pub use prometheus::register;
pub use prometheus::{core, default_registry, proto};
pub use prometheus::{exponential_buckets, linear_buckets};
pub use prometheus::{register_gauge, Gauge};
@@ -18,6 +21,17 @@ pub use prometheus::{Encoder, TextEncoder};
mod wrappers;
pub use wrappers::{CountedReader, CountedWriter};
pub type UIntGauge = GenericGauge<AtomicU64>;
pub type UIntGaugeVec = GenericGaugeVec<AtomicU64>;
/// Creates a [`UIntGaugeVec`] from a name, a help string and a list of label
/// names, and registers it via this crate's re-exported `register` function.
///
/// Evaluates to the result of the registration, with the success value
/// replaced by the newly created gauge vec.
///
/// # Panics
///
/// Panics if the gauge vec itself cannot be constructed from the given
/// arguments (the constructor result is unwrapped).
#[macro_export]
macro_rules! register_uint_gauge_vec {
    ($NAME:expr, $HELP:expr, $LABELS_NAMES:expr $(,)?) => {{
        // Refer to the type through `$crate` so the macro also expands correctly
        // in crates that have not imported `UIntGaugeVec` themselves.
        let gauge_vec =
            $crate::UIntGaugeVec::new($crate::opts!($NAME, $HELP), $LABELS_NAMES).unwrap();
        // Register a clone and hand the original back to the caller.
        $crate::register(Box::new(gauge_vec.clone())).map(|_| gauge_vec)
    }};
}
/// Gathers all Prometheus metrics and records the I/O stats just before that.
///
/// Metrics gathering is a relatively simple and standalone operation, so

View File

@@ -49,12 +49,12 @@ fn main() {
// Finding the location of C headers for the Postgres server:
// - if POSTGRES_INSTALL_DIR is set look into it, otherwise look into `<project_root>/tmp_install`
// - if there's a `bin/pg_config` file use it for getting include server, otherwise use `<project_root>/tmp_install/include/postgresql/server`
let mut pg_install_dir: PathBuf;
if let Some(postgres_install_dir) = env::var_os("POSTGRES_INSTALL_DIR") {
pg_install_dir = postgres_install_dir.into();
let mut pg_install_dir = if let Some(postgres_install_dir) = env::var_os("POSTGRES_INSTALL_DIR")
{
postgres_install_dir.into()
} else {
pg_install_dir = PathBuf::from("tmp_install")
}
PathBuf::from("tmp_install")
};
if pg_install_dir.is_relative() {
let cwd = env::current_dir().unwrap();

View File

@@ -47,10 +47,12 @@ pub enum FeStartupPacket {
StartupMessage {
major_version: u32,
minor_version: u32,
params: HashMap<String, String>,
params: StartupMessageParams,
},
}
pub type StartupMessageParams = HashMap<String, String>;
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct CancelKeyData {
pub backend_pid: i32,

View File

@@ -15,6 +15,5 @@ git-version = "0.3.5"
pageserver = { path = "../pageserver" }
control_plane = { path = "../control_plane" }
safekeeper = { path = "../safekeeper" }
postgres_ffi = { path = "../libs/postgres_ffi" }
utils = { path = "../libs/utils" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }

View File

@@ -29,7 +29,6 @@ postgres-types = { git = "https://github.com/zenithdb/rust-postgres.git", rev="d
postgres-protocol = { git = "https://github.com/zenithdb/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
tokio-stream = "0.1.8"
anyhow = { version = "1.0", features = ["backtrace"] }
crc32c = "0.6.0"
thiserror = "1.0"

View File

@@ -23,8 +23,7 @@ use tar::{Builder, EntryType, Header};
use tracing::*;
use crate::reltag::{RelTag, SlruKind};
use crate::repository::Timeline;
use crate::DatadirTimelineImpl;
use crate::DatadirTimeline;
use postgres_ffi::xlog_utils::*;
use postgres_ffi::*;
use utils::lsn::Lsn;
@@ -32,12 +31,13 @@ use utils::lsn::Lsn;
/// This is short-living object only for the time of tarball creation,
/// created mostly to avoid passing a lot of parameters between various functions
/// used for constructing tarball.
pub struct Basebackup<'a, W>
pub struct Basebackup<'a, W, T>
where
W: Write,
T: DatadirTimeline,
{
ar: Builder<AbortableWrite<W>>,
timeline: &'a Arc<DatadirTimelineImpl>,
timeline: &'a Arc<T>,
pub lsn: Lsn,
prev_record_lsn: Lsn,
full_backup: bool,
@@ -52,17 +52,18 @@ where
// * When working without safekeepers. In this situation it is important to match the lsn
// we are taking basebackup on with the lsn that is used in pageserver's walreceiver
// to start the replication.
impl<'a, W> Basebackup<'a, W>
impl<'a, W, T> Basebackup<'a, W, T>
where
W: Write,
T: DatadirTimeline,
{
pub fn new(
write: W,
timeline: &'a Arc<DatadirTimelineImpl>,
timeline: &'a Arc<T>,
req_lsn: Option<Lsn>,
prev_lsn: Option<Lsn>,
full_backup: bool,
) -> Result<Basebackup<'a, W>> {
) -> Result<Basebackup<'a, W, T>> {
// Compute postgres doesn't have any previous WAL files, but the first
// record that it's going to write needs to include the LSN of the
// previous record (xl_prev). We include prev_record_lsn in the
@@ -79,13 +80,13 @@ where
let (backup_prev, backup_lsn) = if let Some(req_lsn) = req_lsn {
// Backup was requested at a particular LSN. Wait for it to arrive.
info!("waiting for {}", req_lsn);
timeline.tline.wait_lsn(req_lsn)?;
timeline.wait_lsn(req_lsn)?;
// If the requested point is the end of the timeline, we can
// provide prev_lsn. (get_last_record_rlsn() might return it as
// zero, though, if no WAL has been generated on this timeline
// yet.)
let end_of_timeline = timeline.tline.get_last_record_rlsn();
let end_of_timeline = timeline.get_last_record_rlsn();
if req_lsn == end_of_timeline.last {
(end_of_timeline.prev, req_lsn)
} else {
@@ -93,7 +94,7 @@ where
}
} else {
// Backup was requested at end of the timeline.
let end_of_timeline = timeline.tline.get_last_record_rlsn();
let end_of_timeline = timeline.get_last_record_rlsn();
(end_of_timeline.prev, end_of_timeline.last)
};
@@ -371,7 +372,7 @@ where
// add zenith.signal file
let mut zenith_signal = String::new();
if self.prev_record_lsn == Lsn(0) {
if self.lsn == self.timeline.tline.get_ancestor_lsn() {
if self.lsn == self.timeline.get_ancestor_lsn() {
write!(zenith_signal, "PREV LSN: none")?;
} else {
write!(zenith_signal, "PREV LSN: invalid")?;
@@ -402,9 +403,10 @@ where
}
}
impl<'a, W> Drop for Basebackup<'a, W>
impl<'a, W, T> Drop for Basebackup<'a, W, T>
where
W: Write,
T: DatadirTimeline,
{
/// If the basebackup was not finished, prevent the Archive::drop() from
/// writing the end-of-archive marker.

View File

@@ -78,6 +78,11 @@ paths:
schema:
type: string
description: Controls calculation of current_logical_size_non_incremental
- name: include-non-incremental-physical-size
in: query
schema:
type: string
description: Controls calculation of current_physical_size_non_incremental
get:
description: Get timelines for tenant
responses:
@@ -136,6 +141,11 @@ paths:
schema:
type: string
description: Controls calculation of current_logical_size_non_incremental
- name: include-non-incremental-physical-size
in: query
schema:
type: string
description: Controls calculation of current_physical_size_non_incremental
responses:
"200":
description: TimelineInfo
@@ -671,8 +681,12 @@ components:
format: hex
current_logical_size:
type: integer
current_physical_size:
type: integer
current_logical_size_non_incremental:
type: integer
current_physical_size_non_incremental:
type: integer
WalReceiverEntry:
type: object

View File

@@ -113,10 +113,17 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
let tenant_id: ZTenantId = parse_request_param(&request, "tenant_id")?;
check_permission(&request, Some(tenant_id))?;
let include_non_incremental_logical_size = get_include_non_incremental_logical_size(&request);
let include_non_incremental_logical_size =
query_param_present(&request, "include-non-incremental-logical-size");
let include_non_incremental_physical_size =
query_param_present(&request, "include-non-incremental-physical-size");
let local_timeline_infos = tokio::task::spawn_blocking(move || {
let _enter = info_span!("timeline_list", tenant = %tenant_id).entered();
crate::timelines::get_local_timelines(tenant_id, include_non_incremental_logical_size)
crate::timelines::get_local_timelines(
tenant_id,
include_non_incremental_logical_size,
include_non_incremental_physical_size,
)
})
.await
.map_err(ApiError::from_err)??;
@@ -145,17 +152,15 @@ async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>,
json_response(StatusCode::OK, response_data)
}
// Gate non incremental logical size calculation behind a flag
// after pgbench -i -s100 calculation took 28ms so if multiplied by the number of timelines
// and tenants it can take noticeable amount of time. Also the value currently used only in tests
fn get_include_non_incremental_logical_size(request: &Request<Body>) -> bool {
/// Checks if a query param is present in the request's URL
fn query_param_present(request: &Request<Body>, param: &str) -> bool {
request
.uri()
.query()
.map(|v| {
url::form_urlencoded::parse(v.as_bytes())
.into_owned()
.any(|(param, _)| param == "include-non-incremental-logical-size")
.any(|(p, _)| p == param)
})
.unwrap_or(false)
}
@@ -165,7 +170,10 @@ async fn timeline_detail_handler(request: Request<Body>) -> Result<Response<Body
check_permission(&request, Some(tenant_id))?;
let timeline_id: ZTimelineId = parse_request_param(&request, "timeline_id")?;
let include_non_incremental_logical_size = get_include_non_incremental_logical_size(&request);
let include_non_incremental_logical_size =
query_param_present(&request, "include-non-incremental-logical-size");
let include_non_incremental_physical_size =
query_param_present(&request, "include-non-incremental-physical-size");
let (local_timeline_info, remote_timeline_info) = async {
// any error here will render local timeline as None
@@ -181,6 +189,7 @@ async fn timeline_detail_handler(request: Request<Body>) -> Result<Response<Body
timeline_id,
timeline,
include_non_incremental_logical_size,
include_non_incremental_physical_size,
)
})
.transpose()?

View File

@@ -13,9 +13,8 @@ use walkdir::WalkDir;
use crate::pgdatadir_mapping::*;
use crate::reltag::{RelTag, SlruKind};
use crate::repository::Repository;
use crate::repository::Timeline;
use crate::walingest::WalIngest;
use crate::walrecord::DecodedWALRecord;
use postgres_ffi::relfile_utils::*;
use postgres_ffi::waldecoder::*;
use postgres_ffi::xlog_utils::*;
@@ -29,16 +28,16 @@ use utils::lsn::Lsn;
/// This is currently only used to import a cluster freshly created by initdb.
/// The code that deals with the checkpoint would not work right if the
/// cluster was not shut down cleanly.
pub fn import_timeline_from_postgres_datadir<R: Repository>(
pub fn import_timeline_from_postgres_datadir<T: DatadirTimeline>(
path: &Path,
tline: &mut DatadirTimeline<R>,
tline: &T,
lsn: Lsn,
) -> Result<()> {
let mut pg_control: Option<ControlFileData> = None;
// TODO this should be start_lsn, which is not necessarily equal to end_lsn (aka lsn)
// Then fishing out pg_control would be unnecessary
let mut modification = tline.begin_modification(lsn);
let mut modification = tline.begin_modification();
modification.init_empty()?;
// Import all but pg_wal
@@ -57,12 +56,12 @@ pub fn import_timeline_from_postgres_datadir<R: Repository>(
if let Some(control_file) = import_file(&mut modification, relative_path, file, len)? {
pg_control = Some(control_file);
}
modification.flush()?;
modification.flush(lsn)?;
}
}
// We're done importing all the data files.
modification.commit()?;
modification.commit(lsn)?;
// We expect the Postgres server to be shut down cleanly.
let pg_control = pg_control.context("pg_control file not found")?;
@@ -89,8 +88,8 @@ pub fn import_timeline_from_postgres_datadir<R: Repository>(
}
// subroutine of import_timeline_from_postgres_datadir(), to load one relation file.
fn import_rel<R: Repository, Reader: Read>(
modification: &mut DatadirModification<R>,
fn import_rel<T: DatadirTimeline, Reader: Read>(
modification: &mut DatadirModification<T>,
path: &Path,
spcoid: Oid,
dboid: Oid,
@@ -169,8 +168,8 @@ fn import_rel<R: Repository, Reader: Read>(
/// Import an SLRU segment file
///
fn import_slru<R: Repository, Reader: Read>(
modification: &mut DatadirModification<R>,
fn import_slru<T: DatadirTimeline, Reader: Read>(
modification: &mut DatadirModification<T>,
slru: SlruKind,
path: &Path,
mut reader: Reader,
@@ -225,9 +224,9 @@ fn import_slru<R: Repository, Reader: Read>(
/// Scan PostgreSQL WAL files in given directory and load all records between
/// 'startpoint' and 'endpoint' into the repository.
fn import_wal<R: Repository>(
fn import_wal<T: DatadirTimeline>(
walpath: &Path,
tline: &mut DatadirTimeline<R>,
tline: &T,
startpoint: Lsn,
endpoint: Lsn,
) -> Result<()> {
@@ -268,9 +267,11 @@ fn import_wal<R: Repository>(
waldecoder.feed_bytes(&buf);
let mut nrecords = 0;
let mut modification = tline.begin_modification();
let mut decoded = DecodedWALRecord::default();
while last_lsn <= endpoint {
if let Some((lsn, recdata)) = waldecoder.poll_decode()? {
walingest.ingest_record(tline, recdata, lsn)?;
walingest.ingest_record(recdata, lsn, &mut modification, &mut decoded)?;
last_lsn = lsn;
nrecords += 1;
@@ -294,13 +295,13 @@ fn import_wal<R: Repository>(
Ok(())
}
pub fn import_basebackup_from_tar<R: Repository, Reader: Read>(
tline: &mut DatadirTimeline<R>,
pub fn import_basebackup_from_tar<T: DatadirTimeline, Reader: Read>(
tline: &T,
reader: Reader,
base_lsn: Lsn,
) -> Result<()> {
info!("importing base at {}", base_lsn);
let mut modification = tline.begin_modification(base_lsn);
let mut modification = tline.begin_modification();
modification.init_empty()?;
let mut pg_control: Option<ControlFileData> = None;
@@ -318,7 +319,7 @@ pub fn import_basebackup_from_tar<R: Repository, Reader: Read>(
// We found the pg_control file.
pg_control = Some(res);
}
modification.flush()?;
modification.flush(base_lsn)?;
}
tar::EntryType::Directory => {
debug!("directory {:?}", file_path);
@@ -332,12 +333,12 @@ pub fn import_basebackup_from_tar<R: Repository, Reader: Read>(
// sanity check: ensure that pg_control is loaded
let _pg_control = pg_control.context("pg_control file not found")?;
modification.commit()?;
modification.commit(base_lsn)?;
Ok(())
}
pub fn import_wal_from_tar<R: Repository, Reader: Read>(
tline: &mut DatadirTimeline<R>,
pub fn import_wal_from_tar<T: DatadirTimeline, Reader: Read>(
tline: &T,
reader: Reader,
start_lsn: Lsn,
end_lsn: Lsn,
@@ -384,9 +385,11 @@ pub fn import_wal_from_tar<R: Repository, Reader: Read>(
waldecoder.feed_bytes(&bytes[offset..]);
let mut modification = tline.begin_modification();
let mut decoded = DecodedWALRecord::default();
while last_lsn <= end_lsn {
if let Some((lsn, recdata)) = waldecoder.poll_decode()? {
walingest.ingest_record(tline, recdata, lsn)?;
walingest.ingest_record(recdata, lsn, &mut modification, &mut decoded)?;
last_lsn = lsn;
debug!("imported record at {} (end {})", lsn, end_lsn);
@@ -415,8 +418,8 @@ pub fn import_wal_from_tar<R: Repository, Reader: Read>(
Ok(())
}
pub fn import_file<R: Repository, Reader: Read>(
modification: &mut DatadirModification<R>,
pub fn import_file<T: DatadirTimeline, Reader: Read>(
modification: &mut DatadirModification<T>,
file_path: &Path,
reader: Reader,
len: usize,
@@ -535,7 +538,7 @@ pub fn import_file<R: Repository, Reader: Read>(
// zenith.signal is not necessarily the last file, that we handle
// but it is ok to call `finish_write()`, because final `modification.commit()`
// will update lsn once more to the final one.
let writer = modification.tline.tline.writer();
let writer = modification.tline.writer();
writer.finish_write(prev_lsn);
debug!("imported zenith signal {}", prev_lsn);

File diff suppressed because it is too large Load Diff

View File

@@ -316,6 +316,18 @@ impl Layer for DeltaLayer {
}
}
/// Iterate over the keys stored in this delta layer as `(key, lsn, size)` triples.
///
/// The return type has no room for an error, so a layer that cannot be
/// loaded, or whose key iterator cannot be built, aborts with a panic.
fn key_iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Key, Lsn, u64)> + 'a> {
    let inner = self
        .load()
        .unwrap_or_else(|e| panic!("Failed to load a delta layer: {e:?}"));
    let keys = DeltaKeyIter::new(inner)
        .unwrap_or_else(|e| panic!("Layer index is corrupted: {e:?}"));
    Box::new(keys)
}
fn delete(&self) -> Result<()> {
// delete underlying file
fs::remove_file(self.path())?;
@@ -660,11 +672,21 @@ impl DeltaLayerWriter {
/// The values must be appended in key, lsn order.
///
pub fn put_value(&mut self, key: Key, lsn: Lsn, val: Value) -> Result<()> {
self.put_value_bytes(key, lsn, &Value::ser(&val)?, val.will_init())
}
pub fn put_value_bytes(
&mut self,
key: Key,
lsn: Lsn,
val: &[u8],
will_init: bool,
) -> Result<()> {
assert!(self.lsn_range.start <= lsn);
let off = self.blob_writer.write_blob(&Value::ser(&val)?)?;
let off = self.blob_writer.write_blob(val)?;
let blob_ref = BlobRef::new(off, val.will_init());
let blob_ref = BlobRef::new(off, will_init);
let delta_key = DeltaKey::from_key_lsn(&key, lsn);
self.tree.append(&delta_key.0, blob_ref.0)?;
@@ -822,3 +844,75 @@ impl<'a> DeltaValueIter<'a> {
}
}
}
///
/// Iterator over all keys stored in a delta layer
///
/// FIXME: This creates a Vector to hold all keys.
/// That takes up quite a lot of memory. Should do this in a more streaming
/// fashion.
///
struct DeltaKeyIter {
// One entry per distinct key, in index order: the first `DeltaKey` visited
// for that key, paired with the byte span computed by `DeltaKeyIter::new`.
all_keys: Vec<(DeltaKey, u64)>,
// Index into `all_keys` of the entry that `next()` will return.
next_idx: usize,
}
// Walks `all_keys` front to back, yielding `(key, lsn, size)` for each entry.
impl Iterator for DeltaKeyIter {
    type Item = (Key, Lsn, u64);

    fn next(&mut self) -> Option<Self::Item> {
        // `get` yields `None` once the cursor has moved past the last entry,
        // which ends the iteration without advancing the cursor any further.
        let (delta_key, size) = self.all_keys.get(self.next_idx)?;
        let item = (delta_key.key(), delta_key.lsn(), *size);
        self.next_idx += 1;
        Some(item)
    }
}
impl<'a> DeltaKeyIter {
/// Build the iterator by visiting the layer's B-tree index once and
/// collecting one `(DeltaKey, size)` entry per distinct key.
///
/// The size of a key is the distance between the position of its first blob
/// and the position of the next key's first blob; for the last key it is
/// the distance to the end of the layer file.
///
/// Returns an error if visiting the index or reading the file metadata
/// fails. Panics if `inner.file` is not set (the `unwrap()` below).
fn new(inner: RwLockReadGuard<'a, DeltaLayerInner>) -> Result<Self> {
let file = inner.file.as_ref().unwrap();
let tree_reader = DiskBtreeReader::<_, DELTA_KEY_SIZE>::new(
inner.index_start_blk,
inner.index_root_blk,
file,
);
let mut all_keys: Vec<(DeltaKey, u64)> = Vec::new();
// Start from the all-zero key and walk forwards, i.e. over every entry.
tree_reader.visit(
&[0u8; DELTA_KEY_SIZE],
VisitDirection::Forwards,
|key, value| {
let delta_key = DeltaKey::from_slice(key);
let pos = BlobRef(value).pos();
if let Some(last) = all_keys.last_mut() {
if last.0.key() == delta_key.key() {
// Another version of the key we are already tracking: keep
// only the first entry. Presumably `true` tells `visit` to
// continue - confirm against DiskBtreeReader.
return true;
} else {
// The new key's first blob position minus this key's first
// blob position gives the total size of the values
// associated with this key.
let first_pos = last.1;
last.1 = pos - first_pos;
}
}
// Until the next key shows up, the second field temporarily holds the
// position of this key's first blob rather than a size.
all_keys.push((delta_key, pos));
true
},
)?;
if let Some(last) = all_keys.last_mut() {
// Last key occupies all space till end of layer
// NOTE(review): this measures up to the end of the file, so it includes
// anything stored after the last blob - confirm that is intended.
last.1 = std::fs::metadata(&file.file.path)?.len() - last.1;
}
let iter = DeltaKeyIter {
all_keys,
next_idx: 0,
};
Ok(iter)
}
}

View File

@@ -43,7 +43,7 @@ pub struct EphemeralFile {
_timelineid: ZTimelineId,
file: Arc<VirtualFile>,
size: u64,
pub size: u64,
}
impl EphemeralFile {

View File

@@ -15,6 +15,7 @@ use crate::layered_repository::storage_layer::{
use crate::repository::{Key, Value};
use crate::walrecord;
use anyhow::{bail, ensure, Result};
use std::cell::RefCell;
use std::collections::HashMap;
use tracing::*;
use utils::{
@@ -30,6 +31,12 @@ use std::ops::Range;
use std::path::PathBuf;
use std::sync::RwLock;
thread_local! {
/// A per-thread buffer for serializing a value during [`InMemoryLayer::put_value`].
/// The buffer is reused for each serialization to avoid additional malloc calls.
static SER_BUFFER: RefCell<Vec<u8>> = RefCell::new(Vec::new());
}
pub struct InMemoryLayer {
conf: &'static PageServerConf,
tenantid: ZTenantId,
@@ -233,6 +240,14 @@ impl Layer for InMemoryLayer {
}
impl InMemoryLayer {
///
/// Size of this layer's file, as recorded in its `size` field.
///
/// Takes the read lock on `inner`; the result is always `Ok`.
///
pub fn size(&self) -> Result<u64> {
    Ok(self.inner.read().unwrap().file.size)
}
///
/// Create a new, empty, in-memory layer
///
@@ -270,10 +285,17 @@ impl InMemoryLayer {
pub fn put_value(&self, key: Key, lsn: Lsn, val: &Value) -> Result<()> {
trace!("put_value key {} at {}/{}", key, self.timelineid, lsn);
let mut inner = self.inner.write().unwrap();
inner.assert_writeable();
let off = inner.file.write_blob(&Value::ser(val)?)?;
let off = {
SER_BUFFER.with(|x| -> Result<_> {
let mut buf = x.borrow_mut();
buf.clear();
val.ser_into(&mut (*buf))?;
let off = inner.file.write_blob(&buf)?;
Ok(off)
})?
};
let vec_map = inner.index.entry(key).or_default();
let old = vec_map.append_or_update_last(lsn, off).unwrap().0;
@@ -342,8 +364,8 @@ impl InMemoryLayer {
// Write all page versions
for (lsn, pos) in vec_map.as_slice() {
cursor.read_blob_into_buf(*pos, &mut buf)?;
let val = Value::des(&buf)?;
delta_layer_writer.put_value(key, *lsn, val)?;
let will_init = Value::des(&buf)?.will_init();
delta_layer_writer.put_value_bytes(key, *lsn, &buf, will_init)?;
}
}

View File

@@ -10,9 +10,9 @@
//! corresponding files are written to disk.
//!
use crate::layered_repository::inmemory_layer::InMemoryLayer;
use crate::layered_repository::storage_layer::Layer;
use crate::layered_repository::storage_layer::{range_eq, range_overlaps};
use crate::layered_repository::InMemoryLayer;
use crate::repository::Key;
use anyhow::Result;
use lazy_static::lazy_static;

View File

@@ -139,6 +139,12 @@ pub trait Layer: Send + Sync {
/// Iterate through all keys and values stored in the layer
fn iter(&self) -> Box<dyn Iterator<Item = Result<(Key, Lsn, Value)>> + '_>;
/// Iterate through all keys stored in the layer. Returns key, LSN and value size.
///
/// It is used only for compaction and so is currently implemented only for DeltaLayer.
/// This default implementation panics with "Not implemented"; implementors
/// that support it must override the method.
fn key_iter(&self) -> Box<dyn Iterator<Item = (Key, Lsn, u64)> + '_> {
panic!("Not implemented")
}
/// Permanently remove this layer from disk.
fn delete(&self) -> Result<()>;

File diff suppressed because it is too large Load Diff

View File

@@ -63,8 +63,7 @@ pub enum CheckpointConfig {
}
pub type RepositoryImpl = LayeredRepository;
pub type DatadirTimelineImpl = DatadirTimeline<RepositoryImpl>;
pub type TimelineImpl = <LayeredRepository as repository::Repository>::Timeline;
pub fn shutdown_pageserver(exit_code: i32) {
// Shut down the libpq endpoint thread. This prevents new connections from

View File

@@ -30,7 +30,6 @@ use utils::{
use crate::basebackup;
use crate::config::{PageServerConf, ProfilingConfig};
use crate::import_datadir::{import_basebackup_from_tar, import_wal_from_tar};
use crate::layered_repository::LayeredRepository;
use crate::pgdatadir_mapping::{DatadirTimeline, LsnForTimestamp};
use crate::profiling::profpoint_start;
use crate::reltag::RelTag;
@@ -555,9 +554,6 @@ impl PageServerHandler {
info!("creating new timeline");
let repo = tenant_mgr::get_repository_for_tenant(tenant_id)?;
let timeline = repo.create_empty_timeline(timeline_id, base_lsn)?;
let repartition_distance = repo.get_checkpoint_distance();
let mut datadir_timeline =
DatadirTimeline::<LayeredRepository>::new(timeline, repartition_distance);
// TODO mark timeline as not ready until it reaches end_lsn.
// We might have some wal to import as well, and we should prevent compute
@@ -573,7 +569,7 @@ impl PageServerHandler {
info!("importing basebackup");
pgb.write_message(&BeMessage::CopyInResponse)?;
let reader = CopyInReader::new(pgb);
import_basebackup_from_tar(&mut datadir_timeline, reader, base_lsn)?;
import_basebackup_from_tar(&*timeline, reader, base_lsn)?;
// TODO check checksum
// Meanwhile you can verify client-side by taking fullbackup
@@ -583,7 +579,7 @@ impl PageServerHandler {
// Flush data to disk, then upload to s3
info!("flushing layers");
datadir_timeline.tline.checkpoint(CheckpointConfig::Flush)?;
timeline.checkpoint(CheckpointConfig::Flush)?;
info!("done");
Ok(())
@@ -605,10 +601,6 @@ impl PageServerHandler {
let timeline = repo.get_timeline_load(timeline_id)?;
ensure!(timeline.get_last_record_lsn() == start_lsn);
let repartition_distance = repo.get_checkpoint_distance();
let mut datadir_timeline =
DatadirTimeline::<LayeredRepository>::new(timeline, repartition_distance);
// TODO leave clean state on error. For now you can use detach to clean
// up broken state from a failed import.
@@ -616,16 +608,16 @@ impl PageServerHandler {
info!("importing wal");
pgb.write_message(&BeMessage::CopyInResponse)?;
let reader = CopyInReader::new(pgb);
import_wal_from_tar(&mut datadir_timeline, reader, start_lsn, end_lsn)?;
import_wal_from_tar(&*timeline, reader, start_lsn, end_lsn)?;
// TODO Does it make sense to overshoot?
ensure!(datadir_timeline.tline.get_last_record_lsn() >= end_lsn);
ensure!(timeline.get_last_record_lsn() >= end_lsn);
// Flush data to disk, then upload to s3. No need for a forced checkpoint.
// We only want to persist the data, and it doesn't matter if it's in the
// shape of deltas or images.
info!("flushing layers");
datadir_timeline.tline.checkpoint(CheckpointConfig::Flush)?;
timeline.checkpoint(CheckpointConfig::Flush)?;
info!("done");
Ok(())
@@ -643,8 +635,8 @@ impl PageServerHandler {
/// In either case, if the page server hasn't received the WAL up to the
/// requested LSN yet, we will wait for it to arrive. The return value is
/// the LSN that should be used to look up the page versions.
fn wait_or_get_last_lsn<R: Repository>(
timeline: &DatadirTimeline<R>,
fn wait_or_get_last_lsn<T: DatadirTimeline>(
timeline: &T,
mut lsn: Lsn,
latest: bool,
latest_gc_cutoff_lsn: &RwLockReadGuard<Lsn>,
@@ -671,7 +663,7 @@ impl PageServerHandler {
if lsn <= last_record_lsn {
lsn = last_record_lsn;
} else {
timeline.tline.wait_lsn(lsn)?;
timeline.wait_lsn(lsn)?;
// Since we waited for 'lsn' to arrive, that is now the last
// record LSN. (Or close enough for our purposes; the
// last-record LSN can advance immediately after we return
@@ -681,7 +673,7 @@ impl PageServerHandler {
if lsn == Lsn(0) {
bail!("invalid LSN(0) in request");
}
timeline.tline.wait_lsn(lsn)?;
timeline.wait_lsn(lsn)?;
}
ensure!(
lsn >= **latest_gc_cutoff_lsn,
@@ -691,14 +683,14 @@ impl PageServerHandler {
Ok(lsn)
}
fn handle_get_rel_exists_request<R: Repository>(
fn handle_get_rel_exists_request<T: DatadirTimeline>(
&self,
timeline: &DatadirTimeline<R>,
timeline: &T,
req: &PagestreamExistsRequest,
) -> Result<PagestreamBeMessage> {
let _enter = info_span!("get_rel_exists", rel = %req.rel, req_lsn = %req.lsn).entered();
let latest_gc_cutoff_lsn = timeline.tline.get_latest_gc_cutoff_lsn();
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn)?;
let exists = timeline.get_rel_exists(req.rel, lsn)?;
@@ -708,13 +700,13 @@ impl PageServerHandler {
}))
}
fn handle_get_nblocks_request<R: Repository>(
fn handle_get_nblocks_request<T: DatadirTimeline>(
&self,
timeline: &DatadirTimeline<R>,
timeline: &T,
req: &PagestreamNblocksRequest,
) -> Result<PagestreamBeMessage> {
let _enter = info_span!("get_nblocks", rel = %req.rel, req_lsn = %req.lsn).entered();
let latest_gc_cutoff_lsn = timeline.tline.get_latest_gc_cutoff_lsn();
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn)?;
let n_blocks = timeline.get_rel_size(req.rel, lsn)?;
@@ -724,13 +716,13 @@ impl PageServerHandler {
}))
}
fn handle_db_size_request<R: Repository>(
fn handle_db_size_request<T: DatadirTimeline>(
&self,
timeline: &DatadirTimeline<R>,
timeline: &T,
req: &PagestreamDbSizeRequest,
) -> Result<PagestreamBeMessage> {
let _enter = info_span!("get_db_size", dbnode = %req.dbnode, req_lsn = %req.lsn).entered();
let latest_gc_cutoff_lsn = timeline.tline.get_latest_gc_cutoff_lsn();
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn)?;
let total_blocks =
@@ -743,14 +735,14 @@ impl PageServerHandler {
}))
}
fn handle_get_page_at_lsn_request<R: Repository>(
fn handle_get_page_at_lsn_request<T: DatadirTimeline>(
&self,
timeline: &DatadirTimeline<R>,
timeline: &T,
req: &PagestreamGetPageRequest,
) -> Result<PagestreamBeMessage> {
let _enter = info_span!("get_page", rel = %req.rel, blkno = &req.blkno, req_lsn = %req.lsn)
.entered();
let latest_gc_cutoff_lsn = timeline.tline.get_latest_gc_cutoff_lsn();
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn)?;
/*
// Add a 1s delay to some requests. The delay causes the requests to
@@ -783,7 +775,7 @@ impl PageServerHandler {
// check that the timeline exists
let timeline = tenant_mgr::get_local_timeline_with_load(tenantid, timelineid)
.context("Cannot load local timeline")?;
let latest_gc_cutoff_lsn = timeline.tline.get_latest_gc_cutoff_lsn();
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
if let Some(lsn) = lsn {
timeline
.check_lsn_is_in_scope(lsn, &latest_gc_cutoff_lsn)
@@ -921,7 +913,7 @@ impl postgres_backend::Handler for PageServerHandler {
let timeline = tenant_mgr::get_local_timeline_with_load(tenantid, timelineid)
.context("Cannot load local timeline")?;
let end_of_timeline = timeline.tline.get_last_record_rlsn();
let end_of_timeline = timeline.get_last_record_rlsn();
pgb.write_message_noflush(&BeMessage::RowDescription(&[
RowDescriptor::text_col(b"prev_lsn"),
@@ -1139,7 +1131,7 @@ impl postgres_backend::Handler for PageServerHandler {
let timelineid = ZTimelineId::from_str(caps.get(2).unwrap().as_str())?;
let timeline = tenant_mgr::get_local_timeline_with_load(tenantid, timelineid)
.context("Couldn't load timeline")?;
timeline.tline.compact()?;
timeline.compact()?;
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
@@ -1159,13 +1151,8 @@ impl postgres_backend::Handler for PageServerHandler {
let timeline = tenant_mgr::get_local_timeline_with_load(tenantid, timelineid)
.context("Cannot load local timeline")?;
timeline.tline.checkpoint(CheckpointConfig::Forced)?;
// Also compact it.
//
// FIXME: This probably shouldn't be part of a "checkpoint" command, but a
// separate operation. Update the tests if you change this.
timeline.tline.compact()?;
// Checkpoint the timeline and also compact it (due to `CheckpointConfig::Forced`).
timeline.checkpoint(CheckpointConfig::Forced)?;
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;

View File

@@ -6,10 +6,10 @@
//! walingest.rs handles a few things like implicit relation creation and extension.
//! Clarify that)
//!
use crate::keyspace::{KeyPartitioning, KeySpace, KeySpaceAccum};
use crate::keyspace::{KeySpace, KeySpaceAccum};
use crate::reltag::{RelTag, SlruKind};
use crate::repository::Timeline;
use crate::repository::*;
use crate::repository::{Repository, Timeline};
use crate::walrecord::ZenithWalRecord;
use anyhow::{bail, ensure, Result};
use bytes::{Buf, Bytes};
@@ -18,34 +18,12 @@ use postgres_ffi::{pg_constants, Oid, TransactionId};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::ops::Range;
use std::sync::atomic::{AtomicIsize, Ordering};
use std::sync::{Arc, Mutex, RwLockReadGuard};
use tracing::{debug, error, trace, warn};
use tracing::{debug, trace, warn};
use utils::{bin_ser::BeSer, lsn::Lsn};
/// Block number within a relation or SLRU. This matches PostgreSQL's BlockNumber type.
pub type BlockNumber = u32;
pub struct DatadirTimeline<R>
where
R: Repository,
{
/// The underlying key-value store. Callers should not read or modify the
/// data in the underlying store directly. However, it is exposed to have
/// access to information like last-LSN, ancestor, and operations like
/// compaction.
pub tline: Arc<R::Timeline>,
/// When did we last calculate the partitioning?
partitioning: Mutex<(KeyPartitioning, Lsn)>,
/// Configuration: how often should the partitioning be recalculated.
repartition_threshold: u64,
/// Current logical size of the "datadir", at the last LSN.
current_logical_size: AtomicIsize,
}
#[derive(Debug)]
pub enum LsnForTimestamp {
Present(Lsn),
@@ -54,49 +32,50 @@ pub enum LsnForTimestamp {
NoData(Lsn),
}
impl<R: Repository> DatadirTimeline<R> {
pub fn new(tline: Arc<R::Timeline>, repartition_threshold: u64) -> Self {
DatadirTimeline {
tline,
partitioning: Mutex::new((KeyPartitioning::new(), Lsn(0))),
current_logical_size: AtomicIsize::new(0),
repartition_threshold,
}
}
/// (Re-)calculate the logical size of the database at the latest LSN.
///
/// This can be a slow operation.
pub fn init_logical_size(&self) -> Result<()> {
let last_lsn = self.tline.get_last_record_lsn();
self.current_logical_size.store(
self.get_current_logical_size_non_incremental(last_lsn)? as isize,
Ordering::SeqCst,
);
Ok(())
}
///
/// This trait provides all the functionality to store PostgreSQL relations, SLRUs,
/// and other special kinds of files, in a versioned key-value store. The
/// Timeline trait provides the key-value store.
///
/// This is a trait, so that we can easily include all these functions in a Timeline
/// implementation. You're not expected to have different implementations of this trait,
/// rather, this provides an interface and implementation, over Timeline.
///
/// If you wanted to store other kinds of data in the Neon repository, e.g.
/// flat files or MySQL, you would create a new trait like this, with all the
/// functions that make sense for the kind of data you're storing. For flat files,
/// for example, you might have a function like "fn read(path, offset, size)".
/// We might also have that situation in the future, to support multiple PostgreSQL
/// versions, if there are big changes in how the data is organized in the data
/// directory, or if new special files are introduced.
///
pub trait DatadirTimeline: Timeline {
/// Start ingesting a WAL record, or other atomic modification of
/// the timeline.
///
/// This provides a transaction-like interface to perform a bunch
/// of modifications atomically, all stamped with one LSN.
/// of modifications atomically.
///
/// To ingest a WAL record, call begin_modification(lsn) to get a
/// To ingest a WAL record, call begin_modification() to get a
/// DatadirModification object. Use the functions in the object to
/// modify the repository state, updating all the pages and metadata
/// that the WAL record affects. When you're done, call commit() to
/// commit the changes.
/// that the WAL record affects. When you're done, call commit(lsn) to
/// commit the changes. All the changes will be stamped with the specified LSN.
///
/// Calling commit(lsn) will flush all the changes and reset the state,
/// so the `DatadirModification` struct can be reused to perform the next modification.
///
/// Note that any pending modifications you make through the
/// modification object won't be visible to calls to the 'get' and list
/// functions of the timeline until you finish! And if you update the
/// same page twice, the last update wins.
///
pub fn begin_modification(&self, lsn: Lsn) -> DatadirModification<R> {
fn begin_modification(&self) -> DatadirModification<Self>
where
Self: Sized,
{
DatadirModification {
tline: self,
lsn,
pending_updates: HashMap::new(),
pending_deletions: Vec::new(),
pending_nblocks: 0,
@@ -108,7 +87,7 @@ impl<R: Repository> DatadirTimeline<R> {
//------------------------------------------------------------------------------
/// Look up given page version.
pub fn get_rel_page_at_lsn(&self, tag: RelTag, blknum: BlockNumber, lsn: Lsn) -> Result<Bytes> {
fn get_rel_page_at_lsn(&self, tag: RelTag, blknum: BlockNumber, lsn: Lsn) -> Result<Bytes> {
ensure!(tag.relnode != 0, "invalid relnode");
let nblocks = self.get_rel_size(tag, lsn)?;
@@ -121,11 +100,11 @@ impl<R: Repository> DatadirTimeline<R> {
}
let key = rel_block_to_key(tag, blknum);
self.tline.get(key, lsn)
self.get(key, lsn)
}
// Get size of a database in blocks
pub fn get_db_size(&self, spcnode: Oid, dbnode: Oid, lsn: Lsn) -> Result<usize> {
fn get_db_size(&self, spcnode: Oid, dbnode: Oid, lsn: Lsn) -> Result<usize> {
let mut total_blocks = 0;
let rels = self.list_rels(spcnode, dbnode, lsn)?;
@@ -138,7 +117,7 @@ impl<R: Repository> DatadirTimeline<R> {
}
/// Get size of a relation file
pub fn get_rel_size(&self, tag: RelTag, lsn: Lsn) -> Result<BlockNumber> {
fn get_rel_size(&self, tag: RelTag, lsn: Lsn) -> Result<BlockNumber> {
ensure!(tag.relnode != 0, "invalid relnode");
if (tag.forknum == pg_constants::FSM_FORKNUM
@@ -153,17 +132,17 @@ impl<R: Repository> DatadirTimeline<R> {
}
let key = rel_size_to_key(tag);
let mut buf = self.tline.get(key, lsn)?;
let mut buf = self.get(key, lsn)?;
Ok(buf.get_u32_le())
}
/// Does relation exist?
pub fn get_rel_exists(&self, tag: RelTag, lsn: Lsn) -> Result<bool> {
fn get_rel_exists(&self, tag: RelTag, lsn: Lsn) -> Result<bool> {
ensure!(tag.relnode != 0, "invalid relnode");
// fetch directory listing
let key = rel_dir_to_key(tag.spcnode, tag.dbnode);
let buf = self.tline.get(key, lsn)?;
let buf = self.get(key, lsn)?;
let dir = RelDirectory::des(&buf)?;
let exists = dir.rels.get(&(tag.relnode, tag.forknum)).is_some();
@@ -172,10 +151,10 @@ impl<R: Repository> DatadirTimeline<R> {
}
/// Get a list of all existing relations in given tablespace and database.
pub fn list_rels(&self, spcnode: Oid, dbnode: Oid, lsn: Lsn) -> Result<HashSet<RelTag>> {
fn list_rels(&self, spcnode: Oid, dbnode: Oid, lsn: Lsn) -> Result<HashSet<RelTag>> {
// fetch directory listing
let key = rel_dir_to_key(spcnode, dbnode);
let buf = self.tline.get(key, lsn)?;
let buf = self.get(key, lsn)?;
let dir = RelDirectory::des(&buf)?;
let rels: HashSet<RelTag> =
@@ -190,7 +169,7 @@ impl<R: Repository> DatadirTimeline<R> {
}
/// Look up given SLRU page version.
pub fn get_slru_page_at_lsn(
fn get_slru_page_at_lsn(
&self,
kind: SlruKind,
segno: u32,
@@ -198,26 +177,21 @@ impl<R: Repository> DatadirTimeline<R> {
lsn: Lsn,
) -> Result<Bytes> {
let key = slru_block_to_key(kind, segno, blknum);
self.tline.get(key, lsn)
self.get(key, lsn)
}
/// Get size of an SLRU segment
pub fn get_slru_segment_size(
&self,
kind: SlruKind,
segno: u32,
lsn: Lsn,
) -> Result<BlockNumber> {
fn get_slru_segment_size(&self, kind: SlruKind, segno: u32, lsn: Lsn) -> Result<BlockNumber> {
let key = slru_segment_size_to_key(kind, segno);
let mut buf = self.tline.get(key, lsn)?;
let mut buf = self.get(key, lsn)?;
Ok(buf.get_u32_le())
}
/// Does the SLRU segment exist?
pub fn get_slru_segment_exists(&self, kind: SlruKind, segno: u32, lsn: Lsn) -> Result<bool> {
fn get_slru_segment_exists(&self, kind: SlruKind, segno: u32, lsn: Lsn) -> Result<bool> {
// fetch directory listing
let key = slru_dir_to_key(kind);
let buf = self.tline.get(key, lsn)?;
let buf = self.get(key, lsn)?;
let dir = SlruSegmentDirectory::des(&buf)?;
let exists = dir.segments.get(&segno).is_some();
@@ -231,10 +205,10 @@ impl<R: Repository> DatadirTimeline<R> {
/// so it's not well defined which LSN you get if there were multiple commits
/// "in flight" at that point in time.
///
pub fn find_lsn_for_timestamp(&self, search_timestamp: TimestampTz) -> Result<LsnForTimestamp> {
let gc_cutoff_lsn_guard = self.tline.get_latest_gc_cutoff_lsn();
fn find_lsn_for_timestamp(&self, search_timestamp: TimestampTz) -> Result<LsnForTimestamp> {
let gc_cutoff_lsn_guard = self.get_latest_gc_cutoff_lsn();
let min_lsn = *gc_cutoff_lsn_guard;
let max_lsn = self.tline.get_last_record_lsn();
let max_lsn = self.get_last_record_lsn();
// LSNs are always 8-byte aligned. low/mid/high represent the
// LSN divided by 8.
@@ -325,88 +299,51 @@ impl<R: Repository> DatadirTimeline<R> {
}
/// Get a list of SLRU segments
pub fn list_slru_segments(&self, kind: SlruKind, lsn: Lsn) -> Result<HashSet<u32>> {
fn list_slru_segments(&self, kind: SlruKind, lsn: Lsn) -> Result<HashSet<u32>> {
// fetch directory entry
let key = slru_dir_to_key(kind);
let buf = self.tline.get(key, lsn)?;
let buf = self.get(key, lsn)?;
let dir = SlruSegmentDirectory::des(&buf)?;
Ok(dir.segments)
}
pub fn get_relmap_file(&self, spcnode: Oid, dbnode: Oid, lsn: Lsn) -> Result<Bytes> {
fn get_relmap_file(&self, spcnode: Oid, dbnode: Oid, lsn: Lsn) -> Result<Bytes> {
let key = relmap_file_key(spcnode, dbnode);
let buf = self.tline.get(key, lsn)?;
let buf = self.get(key, lsn)?;
Ok(buf)
}
pub fn list_dbdirs(&self, lsn: Lsn) -> Result<HashMap<(Oid, Oid), bool>> {
fn list_dbdirs(&self, lsn: Lsn) -> Result<HashMap<(Oid, Oid), bool>> {
// fetch directory entry
let buf = self.tline.get(DBDIR_KEY, lsn)?;
let buf = self.get(DBDIR_KEY, lsn)?;
let dir = DbDirectory::des(&buf)?;
Ok(dir.dbdirs)
}
/// Fetch the value stored under the two-phase file key of `xid` as of `lsn`
/// and return the bytes unchanged. A failed read is propagated.
// NOTE(review): diff rendering with +/- markers stripped; the doubled
// signature and `let buf` lines look like removed/added pairs
// (`self.tline.get` replaced by `self.get`) -- confirm against the patch.
pub fn get_twophase_file(&self, xid: TransactionId, lsn: Lsn) -> Result<Bytes> {
fn get_twophase_file(&self, xid: TransactionId, lsn: Lsn) -> Result<Bytes> {
let key = twophase_file_key(xid);
let buf = self.tline.get(key, lsn)?;
let buf = self.get(key, lsn)?;
Ok(buf)
}
/// Read the entry stored under `TWOPHASEDIR_KEY` as of `lsn`, deserialize it
/// as a `TwoPhaseDirectory` and return its `xids` set. Read and
/// deserialization errors are propagated.
// NOTE(review): diff rendering with +/- markers stripped; the doubled
// signature and `let buf` lines look like removed/added pairs -- confirm.
pub fn list_twophase_files(&self, lsn: Lsn) -> Result<HashSet<TransactionId>> {
fn list_twophase_files(&self, lsn: Lsn) -> Result<HashSet<TransactionId>> {
// fetch directory entry
let buf = self.tline.get(TWOPHASEDIR_KEY, lsn)?;
let buf = self.get(TWOPHASEDIR_KEY, lsn)?;
let dir = TwoPhaseDirectory::des(&buf)?;
Ok(dir.xids)
}
/// Return the bytes stored under `CONTROLFILE_KEY` as of `lsn`; the result
/// of the underlying `get` call is passed through as-is.
// NOTE(review): diff rendering with +/- markers stripped; the first
// signature/body pair (`pub fn`, `self.tline.get`) looks like the removed
// version and the second pair (`fn`, `self.get`) like its replacement, with
// one shared closing brace -- confirm against the patch.
pub fn get_control_file(&self, lsn: Lsn) -> Result<Bytes> {
self.tline.get(CONTROLFILE_KEY, lsn)
fn get_control_file(&self, lsn: Lsn) -> Result<Bytes> {
self.get(CONTROLFILE_KEY, lsn)
}
pub fn get_checkpoint(&self, lsn: Lsn) -> Result<Bytes> {
self.tline.get(CHECKPOINT_KEY, lsn)
}
/// Get the LSN of the last ingested WAL record.
///
/// This is just a convenience wrapper that calls through to the underlying
/// repository.
pub fn get_last_record_lsn(&self) -> Lsn {
self.tline.get_last_record_lsn()
}
/// Check that it is valid to request operations with that lsn.
///
/// This is just a convenience wrapper that calls through to the underlying
/// repository.
pub fn check_lsn_is_in_scope(
&self,
lsn: Lsn,
latest_gc_cutoff_lsn: &RwLockReadGuard<Lsn>,
) -> Result<()> {
self.tline.check_lsn_is_in_scope(lsn, latest_gc_cutoff_lsn)
}
/// Retrieve current logical size of the timeline
///
/// NOTE: counted incrementally, includes ancestors,
pub fn get_current_logical_size(&self) -> usize {
let current_logical_size = self.current_logical_size.load(Ordering::Acquire);
match usize::try_from(current_logical_size) {
Ok(sz) => sz,
Err(_) => {
error!(
"current_logical_size is out of range: {}",
current_logical_size
);
0
}
}
fn get_checkpoint(&self, lsn: Lsn) -> Result<Bytes> {
self.get(CHECKPOINT_KEY, lsn)
}
/// Does the same as get_current_logical_size but counted on demand.
@@ -414,16 +351,16 @@ impl<R: Repository> DatadirTimeline<R> {
///
/// Only relation blocks are counted currently. That excludes metadata,
/// SLRUs, twophase files etc.
pub fn get_current_logical_size_non_incremental(&self, lsn: Lsn) -> Result<usize> {
fn get_current_logical_size_non_incremental(&self, lsn: Lsn) -> Result<usize> {
// Fetch list of database dirs and iterate them
let buf = self.tline.get(DBDIR_KEY, lsn)?;
let buf = self.get(DBDIR_KEY, lsn)?;
let dbdir = DbDirectory::des(&buf)?;
let mut total_size: usize = 0;
for (spcnode, dbnode) in dbdir.dbdirs.keys() {
for rel in self.list_rels(*spcnode, *dbnode, lsn)? {
let relsize_key = rel_size_to_key(rel);
let mut buf = self.tline.get(relsize_key, lsn)?;
let mut buf = self.get(relsize_key, lsn)?;
let relsize = buf.get_u32_le();
total_size += relsize as usize;
@@ -444,7 +381,7 @@ impl<R: Repository> DatadirTimeline<R> {
result.add_key(DBDIR_KEY);
// Fetch list of database dirs and iterate them
let buf = self.tline.get(DBDIR_KEY, lsn)?;
let buf = self.get(DBDIR_KEY, lsn)?;
let dbdir = DbDirectory::des(&buf)?;
let mut dbs: Vec<(Oid, Oid)> = dbdir.dbdirs.keys().cloned().collect();
@@ -461,7 +398,7 @@ impl<R: Repository> DatadirTimeline<R> {
rels.sort_unstable();
for rel in rels {
let relsize_key = rel_size_to_key(rel);
let mut buf = self.tline.get(relsize_key, lsn)?;
let mut buf = self.get(relsize_key, lsn)?;
let relsize = buf.get_u32_le();
result.add_range(rel_block_to_key(rel, 0)..rel_block_to_key(rel, relsize));
@@ -477,13 +414,13 @@ impl<R: Repository> DatadirTimeline<R> {
] {
let slrudir_key = slru_dir_to_key(kind);
result.add_key(slrudir_key);
let buf = self.tline.get(slrudir_key, lsn)?;
let buf = self.get(slrudir_key, lsn)?;
let dir = SlruSegmentDirectory::des(&buf)?;
let mut segments: Vec<u32> = dir.segments.iter().cloned().collect();
segments.sort_unstable();
for segno in segments {
let segsize_key = slru_segment_size_to_key(kind, segno);
let mut buf = self.tline.get(segsize_key, lsn)?;
let mut buf = self.get(segsize_key, lsn)?;
let segsize = buf.get_u32_le();
result.add_range(
@@ -495,7 +432,7 @@ impl<R: Repository> DatadirTimeline<R> {
// Then pg_twophase
result.add_key(TWOPHASEDIR_KEY);
let buf = self.tline.get(TWOPHASEDIR_KEY, lsn)?;
let buf = self.get(TWOPHASEDIR_KEY, lsn)?;
let twophase_dir = TwoPhaseDirectory::des(&buf)?;
let mut xids: Vec<TransactionId> = twophase_dir.xids.iter().cloned().collect();
xids.sort_unstable();
@@ -508,32 +445,17 @@ impl<R: Repository> DatadirTimeline<R> {
Ok(result.to_keyspace())
}
/// Return a key partitioning together with the LSN it was computed at.
///
/// The last result is cached in `self.partitioning` behind a mutex. It is
/// recomputed (via `collect_keyspace(lsn)` and `partition(partition_size)`)
/// when the cached LSN is `Lsn(0)` or when `lsn` is more than
/// `self.repartition_threshold` past the cached LSN; otherwise a clone of
/// the cached partitioning and its LSN are returned.
// NOTE(review): the hunk header above (-508,32 +445,17) suggests this whole
// function is on the removed side of the diff -- confirm against the patch.
pub fn repartition(&self, lsn: Lsn, partition_size: u64) -> Result<(KeyPartitioning, Lsn)> {
// `unwrap` panics if another holder of the mutex panicked (poisoned lock).
let mut partitioning_guard = self.partitioning.lock().unwrap();
// NOTE(review): the subtraction below is on the raw LSN values; if `lsn`
// could ever be older than the cached LSN it would presumably underflow --
// confirm callers only pass non-decreasing LSNs.
if partitioning_guard.1 == Lsn(0)
|| lsn.0 - partitioning_guard.1 .0 > self.repartition_threshold
{
let keyspace = self.collect_keyspace(lsn)?;
let partitioning = keyspace.partition(partition_size);
*partitioning_guard = (partitioning, lsn);
return Ok((partitioning_guard.0.clone(), lsn));
}
Ok((partitioning_guard.0.clone(), partitioning_guard.1))
}
}
/// DatadirModification represents an operation to ingest an atomic set of
/// updates to the repository. It is created by the 'begin_record'
/// function. It is called for each WAL record, so that all the modifications
/// by a one WAL record appear atomic.
pub struct DatadirModification<'a, R: Repository> {
pub struct DatadirModification<'a, T: DatadirTimeline> {
/// The timeline this modification applies to. You can access this to
/// read the state, but note that any pending updates are *not* reflected
/// in the state in 'tline' yet.
pub tline: &'a DatadirTimeline<R>,
lsn: Lsn,
pub tline: &'a T,
// The modifications are not applied directly to the underlying key-value store.
// The put-functions add the modifications here, and they are flushed to the
@@ -543,7 +465,7 @@ pub struct DatadirModification<'a, R: Repository> {
pending_nblocks: isize,
}
impl<'a, R: Repository> DatadirModification<'a, R> {
impl<'a, T: DatadirTimeline> DatadirModification<'a, T> {
/// Initialize a completely new repository.
///
/// This inserts the directory metadata entries that are assumed to
@@ -920,7 +842,7 @@ impl<'a, R: Repository> DatadirModification<'a, R> {
/// retains all the metadata, but data pages are flushed. That's again OK
/// for bulk import, where you are just loading data pages and won't try to
/// modify the same pages twice.
pub fn flush(&mut self) -> Result<()> {
pub fn flush(&mut self, lsn: Lsn) -> Result<()> {
// Unless we have accumulated a decent amount of changes, it's not worth it
// to scan through the pending_updates list.
let pending_nblocks = self.pending_nblocks;
@@ -928,13 +850,13 @@ impl<'a, R: Repository> DatadirModification<'a, R> {
return Ok(());
}
let writer = self.tline.tline.writer();
let writer = self.tline.writer();
// Flush relation and SLRU data blocks, keep metadata.
let mut result: Result<()> = Ok(());
self.pending_updates.retain(|&key, value| {
if result.is_ok() && (is_rel_block_key(key) || is_slru_block_key(key)) {
result = writer.put(key, self.lsn, value);
result = writer.put(key, lsn, value);
false
} else {
true
@@ -943,10 +865,7 @@ impl<'a, R: Repository> DatadirModification<'a, R> {
result?;
if pending_nblocks != 0 {
self.tline.current_logical_size.fetch_add(
pending_nblocks * pg_constants::BLCKSZ as isize,
Ordering::SeqCst,
);
writer.update_current_logical_size(pending_nblocks * pg_constants::BLCKSZ as isize);
self.pending_nblocks = 0;
}
@@ -956,26 +875,25 @@ impl<'a, R: Repository> DatadirModification<'a, R> {
///
/// Finish this atomic update, writing all the updated keys to the
/// underlying timeline.
/// All the modifications in this atomic update are stamped by the specified LSN.
///
pub fn commit(self) -> Result<()> {
let writer = self.tline.tline.writer();
pub fn commit(&mut self, lsn: Lsn) -> Result<()> {
let writer = self.tline.writer();
let pending_nblocks = self.pending_nblocks;
self.pending_nblocks = 0;
for (key, value) in self.pending_updates {
writer.put(key, self.lsn, &value)?;
for (key, value) in self.pending_updates.drain() {
writer.put(key, lsn, &value)?;
}
for key_range in self.pending_deletions {
writer.delete(key_range.clone(), self.lsn)?;
for key_range in self.pending_deletions.drain(..) {
writer.delete(key_range, lsn)?;
}
writer.finish_write(self.lsn);
writer.finish_write(lsn);
if pending_nblocks != 0 {
self.tline.current_logical_size.fetch_add(
pending_nblocks * pg_constants::BLCKSZ as isize,
Ordering::SeqCst,
);
writer.update_current_logical_size(pending_nblocks * pg_constants::BLCKSZ as isize);
}
Ok(())
@@ -1002,7 +920,7 @@ impl<'a, R: Repository> DatadirModification<'a, R> {
}
} else {
let last_lsn = self.tline.get_last_record_lsn();
self.tline.tline.get(key, last_lsn)
self.tline.get(key, last_lsn)
}
}
@@ -1404,13 +1322,12 @@ fn is_slru_block_key(key: Key) -> bool {
/// Create an empty timeline `timeline_id` in `repo` at `Lsn(8)`, run
/// `init_empty()` inside a modification, commit it, and return the timeline.
/// Any error from creation, initialization or commit is propagated.
// NOTE(review): diff rendering with +/- markers stripped. The doubled return
// type, `begin_modification`, `commit` and `Ok(..)` lines look like
// removed/added pairs: the earlier variant wraps the timeline in
// `DatadirTimeline::new(.., 256 * 1024)` and passes the LSN to
// `begin_modification`, the later one passes the LSN to `commit` and returns
// the repository's own timeline handle -- confirm against the patch.
pub fn create_test_timeline<R: Repository>(
repo: R,
timeline_id: utils::zid::ZTimelineId,
) -> Result<Arc<crate::DatadirTimeline<R>>> {
) -> Result<std::sync::Arc<R::Timeline>> {
let tline = repo.create_empty_timeline(timeline_id, Lsn(8))?;
let tline = DatadirTimeline::new(tline, 256 * 1024);
let mut m = tline.begin_modification(Lsn(8));
let mut m = tline.begin_modification();
m.init_empty()?;
m.commit()?;
Ok(Arc::new(tline))
m.commit(Lsn(8))?;
Ok(tline)
}
#[allow(clippy::bool_assert_comparison)]
@@ -1483,7 +1400,7 @@ mod tests {
.contains(&TESTREL_A));
// Run checkpoint and garbage collection and check that it's still not visible
newtline.tline.checkpoint(CheckpointConfig::Forced)?;
newtline.checkpoint(CheckpointConfig::Forced)?;
repo.gc_iteration(Some(NEW_TIMELINE_ID), 0, true)?;
assert!(!newtline

View File

@@ -185,7 +185,7 @@ impl Value {
/// A repository corresponds to one .neon directory. One repository holds multiple
/// timelines, forked off from the same initial call to 'initdb'.
pub trait Repository: Send + Sync {
type Timeline: Timeline;
type Timeline: crate::DatadirTimeline;
/// Updates timeline based on the `TimelineSyncStatusUpdate`, received from the remote storage synchronization.
/// See [`crate::remote_storage`] for more details about the synchronization.
@@ -382,6 +382,11 @@ pub trait Timeline: Send + Sync {
lsn: Lsn,
latest_gc_cutoff_lsn: &RwLockReadGuard<Lsn>,
) -> Result<()>;
/// Get the physical size of the timeline at the latest LSN
fn get_physical_size(&self) -> u64;
/// Get the physical size of the timeline at the latest LSN non incrementally
fn get_physical_size_non_incremental(&self) -> Result<u64>;
}
/// Various functions to mutate the timeline.
@@ -405,6 +410,8 @@ pub trait TimelineWriter<'a> {
/// the 'lsn' or anything older. The previous last record LSN is stored alongside
/// the latest and can be read.
fn finish_write(&self, lsn: Lsn);
fn update_current_logical_size(&self, delta: isize);
}
#[cfg(test)]

View File

@@ -176,7 +176,6 @@ use crate::{
layered_repository::{
ephemeral_file::is_ephemeral_file,
metadata::{metadata_path, TimelineMetadata, METADATA_FILE_NAME},
LayeredRepository,
},
storage_sync::{self, index::RemoteIndex},
tenant_mgr::attach_downloaded_tenants,
@@ -1257,7 +1256,13 @@ async fn update_local_metadata(
timeline_id,
} = sync_id;
tokio::task::spawn_blocking(move || {
LayeredRepository::save_metadata(conf, timeline_id, tenant_id, &cloned_metadata, true)
crate::layered_repository::save_metadata(
conf,
timeline_id,
tenant_id,
&cloned_metadata,
true,
)
})
.await
.with_context(|| {

View File

@@ -3,7 +3,6 @@
use crate::config::PageServerConf;
use crate::layered_repository::{load_metadata, LayeredRepository};
use crate::pgdatadir_mapping::DatadirTimeline;
use crate::repository::Repository;
use crate::storage_sync::index::{RemoteIndex, RemoteTimelineIndex};
use crate::storage_sync::{self, LocalTimelineInitStatus, SyncStartupData};
@@ -12,7 +11,7 @@ use crate::thread_mgr::ThreadKind;
use crate::timelines::CreateRepo;
use crate::walredo::PostgresRedoManager;
use crate::{thread_mgr, timelines, walreceiver};
use crate::{DatadirTimelineImpl, RepositoryImpl};
use crate::{RepositoryImpl, TimelineImpl};
use anyhow::Context;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
@@ -101,7 +100,7 @@ struct Tenant {
///
/// Local timelines have more metadata that's loaded into memory,
/// that is located in the `repo.timelines` field, [`crate::layered_repository::LayeredTimelineEntry`].
local_timelines: HashMap<ZTimelineId, Arc<DatadirTimelineImpl>>,
local_timelines: HashMap<ZTimelineId, Arc<<RepositoryImpl as Repository>::Timeline>>,
}
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
@@ -178,7 +177,7 @@ pub enum LocalTimelineUpdate {
},
Attach {
id: ZTenantTimelineId,
datadir: Arc<DatadirTimelineImpl>,
datadir: Arc<<RepositoryImpl as Repository>::Timeline>,
},
}
@@ -382,7 +381,7 @@ pub fn get_repository_for_tenant(tenant_id: ZTenantId) -> anyhow::Result<Arc<Rep
pub fn get_local_timeline_with_load(
tenant_id: ZTenantId,
timeline_id: ZTimelineId,
) -> anyhow::Result<Arc<DatadirTimelineImpl>> {
) -> anyhow::Result<Arc<TimelineImpl>> {
let mut m = tenants_state::write_tenants();
let tenant = m
.get_mut(&tenant_id)
@@ -489,23 +488,18 @@ pub fn detach_tenant(conf: &'static PageServerConf, tenant_id: ZTenantId) -> any
/// Load timeline `timeline_id` from `repo`, initialize its logical size,
/// announce it with a `LocalTimelineUpdate::Attach` message, and return the
/// shared handle.
///
/// Fails (with added context) when `get_timeline_load` fails, and
/// propagates an error from `init_logical_size`.
// NOTE(review): diff rendering with +/- markers stripped. The lines that
// build `page_tline` via `DatadirTimelineImpl::new(.., repartition_distance)`
// and the lines that use `inmem_timeline` directly look like the removed and
// added variants of the same steps -- confirm against the patch.
fn load_local_timeline(
repo: &RepositoryImpl,
timeline_id: ZTimelineId,
) -> anyhow::Result<Arc<DatadirTimeline<LayeredRepository>>> {
) -> anyhow::Result<Arc<TimelineImpl>> {
let inmem_timeline = repo.get_timeline_load(timeline_id).with_context(|| {
format!("Inmem timeline {timeline_id} not found in tenant's repository")
})?;
// One tenth of the repository's checkpoint distance.
let repartition_distance = repo.get_checkpoint_distance() / 10;
let page_tline = Arc::new(DatadirTimelineImpl::new(
inmem_timeline,
repartition_distance,
));
page_tline.init_logical_size()?;
inmem_timeline.init_logical_size()?;
tenants_state::try_send_timeline_update(LocalTimelineUpdate::Attach {
id: ZTenantTimelineId::new(repo.tenant_id(), timeline_id),
datadir: Arc::clone(&page_tline),
datadir: Arc::clone(&inmem_timeline),
});
Ok(page_tline)
Ok(inmem_timeline)
}
#[serde_as]

View File

@@ -120,6 +120,10 @@ pub fn init_tenant_task_pool() -> anyhow::Result<()> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.thread_name("tenant-task-worker")
.enable_all()
.on_thread_start(|| {
thread_mgr::register(ThreadKind::TenantTaskWorker, "tenant-task-worker")
})
.on_thread_stop(thread_mgr::deregister)
.build()?;
let (gc_send, mut gc_recv) = mpsc::channel::<ZTenantId>(100);

View File

@@ -97,6 +97,9 @@ pub enum ThreadKind {
// Thread that schedules new compaction and gc jobs
TenantTaskManager,
// Worker thread for tenant tasks thread pool
TenantTaskWorker,
// Thread that flushes frozen in-memory layers to disk
LayerFlushThread,
@@ -105,18 +108,20 @@ pub enum ThreadKind {
StorageSync,
}
/// The part of a thread's bookkeeping record that can change after the
/// record is created; `PageServerThread` keeps it behind a `Mutex` in its
/// `mutable` field.
// NOTE(review): diff rendering with +/- markers stripped; the two
// "the thread has already exited" comment lines below look like the removed
// line followed by its replacement -- confirm against the patch.
#[derive(Default)]
struct MutableThreadState {
/// Tenant and timeline that this thread is associated with.
tenant_id: Option<ZTenantId>,
timeline_id: Option<ZTimelineId>,
/// Handle for waiting for the thread to exit. It can be None, if
/// the thread has already exited.
/// the thread has already exited. OR if this thread is managed externally
/// and was not spawned through thread_mgr.rs::spawn function.
join_handle: Option<JoinHandle<()>>,
}
struct PageServerThread {
_thread_id: u64,
thread_id: u64,
kind: ThreadKind,
@@ -147,7 +152,7 @@ where
let (shutdown_tx, shutdown_rx) = watch::channel(());
let thread_id = NEXT_THREAD_ID.fetch_add(1, Ordering::Relaxed);
let thread = Arc::new(PageServerThread {
_thread_id: thread_id,
thread_id,
kind,
name: name.to_string(),
shutdown_requested: AtomicBool::new(false),
@@ -315,8 +320,10 @@ pub fn shutdown_threads(
drop(thread_mut);
let _ = join_handle.join();
} else {
// The thread had not even fully started yet. Or it was shut down
// concurrently and already exited
// Possibly one of:
// * The thread had not even fully started yet.
// * It was shut down concurrently and already exited
// * Is managed through `register`/`deregister` fns without providing a join handle
}
}
}
@@ -348,3 +355,56 @@ pub fn is_shutdown_requested() -> bool {
}
})
}
/// Make the calling thread known to the thread manager even though it was
/// not started through the `spawn` function (tokio's own threads are the
/// motivating example). Every call is expected to be paired with a later
/// call to `deregister`.
///
/// A fresh thread id is allocated and a `PageServerThread` record with no
/// tenant, no timeline and no join handle is built. The record is stored in
/// `CURRENT_THREAD`, its shutdown receiver in `SHUTDOWN_RX`, and a second
/// reference to the record in the global `THREADS` map.
///
/// NOTE: because no join handle is recorded, threads registered through this
/// function cannot be joined.
///
/// # Panics
///
/// Panics with "thread already registered" if `CURRENT_THREAD` already holds
/// a record.
pub fn register(kind: ThreadKind, name: &str) {
    CURRENT_THREAD.with(|current| {
        let mut slot = current.borrow_mut();
        assert!(slot.is_none(), "thread already registered");

        let (shutdown_tx, shutdown_rx) = watch::channel(());
        let id = NEXT_THREAD_ID.fetch_add(1, Ordering::Relaxed);

        // No tenant/timeline association yet, and nothing to join on.
        let state = Mutex::new(MutableThreadState {
            tenant_id: None,
            timeline_id: None,
            join_handle: None,
        });
        let record = Arc::new(PageServerThread {
            thread_id: id,
            kind,
            name: name.to_owned(),
            shutdown_requested: AtomicBool::new(false),
            shutdown_tx,
            mutable: state,
        });

        // One reference stays with this thread, the other goes in the registry.
        *slot = Some(Arc::clone(&record));
        SHUTDOWN_RX.with(|rx| *rx.borrow_mut() = Some(shutdown_rx));
        THREADS.lock().unwrap().insert(id, record);
    });
}
/// Undo a previous `register` call: take the record out of `CURRENT_THREAD`,
/// clear `SHUTDOWN_RX`, and remove the record's id from the global `THREADS`
/// map. Expected to be used in tandem with `register`; see its documentation
/// for more details.
///
/// # Panics
///
/// Panics with "calling deregister on unregistered thread" if
/// `CURRENT_THREAD` holds no record.
pub fn deregister() {
    CURRENT_THREAD.with(|current| {
        let mut slot = current.borrow_mut();
        let record = slot
            .take()
            .unwrap_or_else(|| panic!("calling deregister on unregistered thread"));

        SHUTDOWN_RX.with(|rx| *rx.borrow_mut() = None);

        // The removed registry entry (if any) is simply dropped.
        THREADS.lock().unwrap().remove(&record.thread_id);
    });
}

View File

@@ -26,7 +26,7 @@ use crate::{
repository::{LocalTimelineState, Repository},
storage_sync::index::RemoteIndex,
tenant_config::TenantConfOpt,
DatadirTimeline, RepositoryImpl,
DatadirTimeline, RepositoryImpl, TimelineImpl,
};
use crate::{import_datadir, LOG_FILE_NAME};
use crate::{layered_repository::LayeredRepository, walredo::WalRedoManager};
@@ -49,32 +49,41 @@ pub struct LocalTimelineInfo {
#[serde_as(as = "DisplayFromStr")]
pub disk_consistent_lsn: Lsn,
pub current_logical_size: Option<usize>, // is None when timeline is Unloaded
pub current_physical_size: Option<u64>, // is None when timeline is Unloaded
pub current_logical_size_non_incremental: Option<usize>,
pub current_physical_size_non_incremental: Option<u64>,
pub timeline_state: LocalTimelineState,
}
impl LocalTimelineInfo {
pub fn from_loaded_timeline<R: Repository>(
datadir_tline: &DatadirTimeline<R>,
pub fn from_loaded_timeline(
timeline: &TimelineImpl,
include_non_incremental_logical_size: bool,
include_non_incremental_physical_size: bool,
) -> anyhow::Result<Self> {
let last_record_lsn = datadir_tline.tline.get_last_record_lsn();
let last_record_lsn = timeline.get_last_record_lsn();
let info = LocalTimelineInfo {
ancestor_timeline_id: datadir_tline.tline.get_ancestor_timeline_id(),
ancestor_timeline_id: timeline.get_ancestor_timeline_id(),
ancestor_lsn: {
match datadir_tline.tline.get_ancestor_lsn() {
match timeline.get_ancestor_lsn() {
Lsn(0) => None,
lsn @ Lsn(_) => Some(lsn),
}
},
disk_consistent_lsn: datadir_tline.tline.get_disk_consistent_lsn(),
disk_consistent_lsn: timeline.get_disk_consistent_lsn(),
last_record_lsn,
prev_record_lsn: Some(datadir_tline.tline.get_prev_record_lsn()),
latest_gc_cutoff_lsn: *datadir_tline.tline.get_latest_gc_cutoff_lsn(),
prev_record_lsn: Some(timeline.get_prev_record_lsn()),
latest_gc_cutoff_lsn: *timeline.get_latest_gc_cutoff_lsn(),
timeline_state: LocalTimelineState::Loaded,
current_logical_size: Some(datadir_tline.get_current_logical_size()),
current_physical_size: Some(timeline.get_physical_size()),
current_logical_size: Some(timeline.get_current_logical_size()),
current_logical_size_non_incremental: if include_non_incremental_logical_size {
Some(datadir_tline.get_current_logical_size_non_incremental(last_record_lsn)?)
Some(timeline.get_current_logical_size_non_incremental(last_record_lsn)?)
} else {
None
},
current_physical_size_non_incremental: if include_non_incremental_physical_size {
Some(timeline.get_physical_size_non_incremental()?)
} else {
None
},
@@ -97,7 +106,9 @@ impl LocalTimelineInfo {
latest_gc_cutoff_lsn: metadata.latest_gc_cutoff_lsn(),
timeline_state: LocalTimelineState::Unloaded,
current_logical_size: None,
current_physical_size: None,
current_logical_size_non_incremental: None,
current_physical_size_non_incremental: None,
}
}
@@ -106,12 +117,16 @@ impl LocalTimelineInfo {
timeline_id: ZTimelineId,
repo_timeline: &RepositoryTimeline<T>,
include_non_incremental_logical_size: bool,
include_non_incremental_physical_size: bool,
) -> anyhow::Result<Self> {
match repo_timeline {
RepositoryTimeline::Loaded(_) => {
let datadir_tline =
tenant_mgr::get_local_timeline_with_load(tenant_id, timeline_id)?;
Self::from_loaded_timeline(&datadir_tline, include_non_incremental_logical_size)
let timeline = tenant_mgr::get_local_timeline_with_load(tenant_id, timeline_id)?;
Self::from_loaded_timeline(
&*timeline,
include_non_incremental_logical_size,
include_non_incremental_physical_size,
)
}
RepositoryTimeline::Unloaded { metadata } => Ok(Self::from_unloaded_timeline(metadata)),
}
@@ -298,19 +313,18 @@ fn bootstrap_timeline<R: Repository>(
// Initdb lsn will be equal to last_record_lsn which will be set after import.
// Because we know it upfront avoid having an option or dummy zero value by passing it to create_empty_timeline.
let timeline = repo.create_empty_timeline(tli, lsn)?;
let mut page_tline: DatadirTimeline<R> = DatadirTimeline::new(timeline, u64::MAX);
import_datadir::import_timeline_from_postgres_datadir(&pgdata_path, &mut page_tline, lsn)?;
import_datadir::import_timeline_from_postgres_datadir(&pgdata_path, &*timeline, lsn)?;
fail::fail_point!("before-checkpoint-new-timeline", |_| {
bail!("failpoint before-checkpoint-new-timeline");
});
page_tline.tline.checkpoint(CheckpointConfig::Forced)?;
timeline.checkpoint(CheckpointConfig::Forced)?;
info!(
"created root timeline {} timeline.lsn {}",
tli,
page_tline.tline.get_last_record_lsn()
timeline.get_last_record_lsn()
);
// Remove temp dir. We don't need it anymore
@@ -322,6 +336,7 @@ fn bootstrap_timeline<R: Repository>(
pub(crate) fn get_local_timelines(
tenant_id: ZTenantId,
include_non_incremental_logical_size: bool,
include_non_incremental_physical_size: bool,
) -> Result<Vec<(ZTimelineId, LocalTimelineInfo)>> {
let repo = tenant_mgr::get_repository_for_tenant(tenant_id)
.with_context(|| format!("Failed to get repo for tenant {}", tenant_id))?;
@@ -336,6 +351,7 @@ pub(crate) fn get_local_timelines(
timeline_id,
&repository_timeline,
include_non_incremental_logical_size,
include_non_incremental_physical_size,
)?,
))
}
@@ -389,7 +405,7 @@ pub(crate) fn create_timeline(
// load the timeline into memory
let loaded_timeline =
tenant_mgr::get_local_timeline_with_load(tenant_id, new_timeline_id)?;
LocalTimelineInfo::from_loaded_timeline(&loaded_timeline, false)
LocalTimelineInfo::from_loaded_timeline(&*loaded_timeline, false, false)
.context("cannot fill timeline info")?
}
None => {
@@ -397,7 +413,7 @@ pub(crate) fn create_timeline(
// load the timeline into memory
let new_timeline =
tenant_mgr::get_local_timeline_with_load(tenant_id, new_timeline_id)?;
LocalTimelineInfo::from_loaded_timeline(&new_timeline, false)
LocalTimelineInfo::from_loaded_timeline(&*new_timeline, false, false)
.context("cannot fill timeline info")?
}
};

View File

@@ -34,7 +34,6 @@ use std::collections::HashMap;
use crate::pgdatadir_mapping::*;
use crate::reltag::{RelTag, SlruKind};
use crate::repository::Repository;
use crate::walrecord::*;
use postgres_ffi::nonrelfile_utils::mx_offset_to_member_segment;
use postgres_ffi::xlog_utils::*;
@@ -44,8 +43,8 @@ use utils::lsn::Lsn;
static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; 8192]);
pub struct WalIngest<'a, R: Repository> {
timeline: &'a DatadirTimeline<R>,
pub struct WalIngest<'a, T: DatadirTimeline> {
timeline: &'a T,
checkpoint: CheckPoint,
checkpoint_modified: bool,
@@ -53,8 +52,8 @@ pub struct WalIngest<'a, R: Repository> {
relsize_cache: HashMap<RelTag, BlockNumber>,
}
impl<'a, R: Repository> WalIngest<'a, R> {
pub fn new(timeline: &DatadirTimeline<R>, startpoint: Lsn) -> Result<WalIngest<R>> {
impl<'a, T: DatadirTimeline> WalIngest<'a, T> {
pub fn new(timeline: &T, startpoint: Lsn) -> Result<WalIngest<T>> {
// Fetch the latest checkpoint into memory, so that we can compare with it
// quickly in `ingest_record` and update it when it changes.
let checkpoint_bytes = timeline.get_checkpoint(startpoint)?;
@@ -78,13 +77,13 @@ impl<'a, R: Repository> WalIngest<'a, R> {
///
pub fn ingest_record(
&mut self,
timeline: &DatadirTimeline<R>,
recdata: Bytes,
lsn: Lsn,
modification: &mut DatadirModification<T>,
decoded: &mut DecodedWALRecord,
) -> Result<()> {
let mut modification = timeline.begin_modification(lsn);
decode_wal_record(recdata, decoded).context("failed decoding wal record")?;
let mut decoded = decode_wal_record(recdata).context("failed decoding wal record")?;
let mut buf = decoded.record.clone();
buf.advance(decoded.main_data_offset);
@@ -98,7 +97,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
if decoded.xl_rmid == pg_constants::RM_HEAP_ID
|| decoded.xl_rmid == pg_constants::RM_HEAP2_ID
{
self.ingest_heapam_record(&mut buf, &mut modification, &mut decoded)?;
self.ingest_heapam_record(&mut buf, modification, decoded)?;
}
// Handle other special record types
if decoded.xl_rmid == pg_constants::RM_SMGR_ID
@@ -106,19 +105,19 @@ impl<'a, R: Repository> WalIngest<'a, R> {
== pg_constants::XLOG_SMGR_CREATE
{
let create = XlSmgrCreate::decode(&mut buf);
self.ingest_xlog_smgr_create(&mut modification, &create)?;
self.ingest_xlog_smgr_create(modification, &create)?;
} else if decoded.xl_rmid == pg_constants::RM_SMGR_ID
&& (decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK)
== pg_constants::XLOG_SMGR_TRUNCATE
{
let truncate = XlSmgrTruncate::decode(&mut buf);
self.ingest_xlog_smgr_truncate(&mut modification, &truncate)?;
self.ingest_xlog_smgr_truncate(modification, &truncate)?;
} else if decoded.xl_rmid == pg_constants::RM_DBASE_ID {
if (decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK)
== pg_constants::XLOG_DBASE_CREATE
{
let createdb = XlCreateDatabase::decode(&mut buf);
self.ingest_xlog_dbase_create(&mut modification, &createdb)?;
self.ingest_xlog_dbase_create(modification, &createdb)?;
} else if (decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK)
== pg_constants::XLOG_DBASE_DROP
{
@@ -137,7 +136,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
let segno = pageno / pg_constants::SLRU_PAGES_PER_SEGMENT;
let rpageno = pageno % pg_constants::SLRU_PAGES_PER_SEGMENT;
self.put_slru_page_image(
&mut modification,
modification,
SlruKind::Clog,
segno,
rpageno,
@@ -146,7 +145,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
} else {
assert!(info == pg_constants::CLOG_TRUNCATE);
let xlrec = XlClogTruncate::decode(&mut buf);
self.ingest_clog_truncate_record(&mut modification, &xlrec)?;
self.ingest_clog_truncate_record(modification, &xlrec)?;
}
} else if decoded.xl_rmid == pg_constants::RM_XACT_ID {
let info = decoded.xl_info & pg_constants::XLOG_XACT_OPMASK;
@@ -154,7 +153,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
let parsed_xact =
XlXactParsedRecord::decode(&mut buf, decoded.xl_xid, decoded.xl_info);
self.ingest_xact_record(
&mut modification,
modification,
&parsed_xact,
info == pg_constants::XLOG_XACT_COMMIT,
)?;
@@ -164,7 +163,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
let parsed_xact =
XlXactParsedRecord::decode(&mut buf, decoded.xl_xid, decoded.xl_info);
self.ingest_xact_record(
&mut modification,
modification,
&parsed_xact,
info == pg_constants::XLOG_XACT_COMMIT_PREPARED,
)?;
@@ -187,7 +186,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
let segno = pageno / pg_constants::SLRU_PAGES_PER_SEGMENT;
let rpageno = pageno % pg_constants::SLRU_PAGES_PER_SEGMENT;
self.put_slru_page_image(
&mut modification,
modification,
SlruKind::MultiXactOffsets,
segno,
rpageno,
@@ -198,7 +197,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
let segno = pageno / pg_constants::SLRU_PAGES_PER_SEGMENT;
let rpageno = pageno % pg_constants::SLRU_PAGES_PER_SEGMENT;
self.put_slru_page_image(
&mut modification,
modification,
SlruKind::MultiXactMembers,
segno,
rpageno,
@@ -206,14 +205,14 @@ impl<'a, R: Repository> WalIngest<'a, R> {
)?;
} else if info == pg_constants::XLOG_MULTIXACT_CREATE_ID {
let xlrec = XlMultiXactCreate::decode(&mut buf);
self.ingest_multixact_create_record(&mut modification, &xlrec)?;
self.ingest_multixact_create_record(modification, &xlrec)?;
} else if info == pg_constants::XLOG_MULTIXACT_TRUNCATE_ID {
let xlrec = XlMultiXactTruncate::decode(&mut buf);
self.ingest_multixact_truncate_record(&mut modification, &xlrec)?;
self.ingest_multixact_truncate_record(modification, &xlrec)?;
}
} else if decoded.xl_rmid == pg_constants::RM_RELMAP_ID {
let xlrec = XlRelmapUpdate::decode(&mut buf);
self.ingest_relmap_page(&mut modification, &xlrec, &decoded)?;
self.ingest_relmap_page(modification, &xlrec, decoded)?;
} else if decoded.xl_rmid == pg_constants::RM_XLOG_ID {
let info = decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK;
if info == pg_constants::XLOG_NEXTOID {
@@ -248,7 +247,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
// Iterate through all the blocks that the record modifies, and
// "put" a separate copy of the record for each block.
for blk in decoded.blocks.iter() {
self.ingest_decoded_block(&mut modification, lsn, &decoded, blk)?;
self.ingest_decoded_block(modification, lsn, decoded, blk)?;
}
// If checkpoint data was updated, store the new version in the repository
@@ -261,14 +260,14 @@ impl<'a, R: Repository> WalIngest<'a, R> {
// Now that this record has been fully handled, including updating the
// checkpoint data, let the repository know that it is up-to-date to this LSN
modification.commit()?;
modification.commit(lsn)?;
Ok(())
}
fn ingest_decoded_block(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
lsn: Lsn,
decoded: &DecodedWALRecord,
blk: &DecodedBkpBlock,
@@ -328,7 +327,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn ingest_heapam_record(
&mut self,
buf: &mut Bytes,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
decoded: &mut DecodedWALRecord,
) -> Result<()> {
// Handle VM bit updates that are implicitly part of heap records.
@@ -472,7 +471,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
/// Subroutine of ingest_record(), to handle an XLOG_DBASE_CREATE record.
fn ingest_xlog_dbase_create(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
rec: &XlCreateDatabase,
) -> Result<()> {
let db_id = rec.db_id;
@@ -539,7 +538,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn ingest_xlog_smgr_create(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
rec: &XlSmgrCreate,
) -> Result<()> {
let rel = RelTag {
@@ -557,7 +556,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
/// This is the same logic as in PostgreSQL's smgr_redo() function.
fn ingest_xlog_smgr_truncate(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
rec: &XlSmgrTruncate,
) -> Result<()> {
let spcnode = rec.rnode.spcnode;
@@ -622,7 +621,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
///
fn ingest_xact_record(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
parsed: &XlXactParsedRecord,
is_commit: bool,
) -> Result<()> {
@@ -691,7 +690,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn ingest_clog_truncate_record(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
xlrec: &XlClogTruncate,
) -> Result<()> {
info!(
@@ -749,7 +748,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn ingest_multixact_create_record(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
xlrec: &XlMultiXactCreate,
) -> Result<()> {
// Create WAL record for updating the multixact-offsets page
@@ -828,7 +827,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn ingest_multixact_truncate_record(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
xlrec: &XlMultiXactTruncate,
) -> Result<()> {
self.checkpoint.oldestMulti = xlrec.end_trunc_off;
@@ -862,7 +861,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn ingest_relmap_page(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
xlrec: &XlRelmapUpdate,
decoded: &DecodedWALRecord,
) -> Result<()> {
@@ -878,7 +877,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn put_rel_creation(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
rel: RelTag,
) -> Result<()> {
self.relsize_cache.insert(rel, 0);
@@ -888,7 +887,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn put_rel_page_image(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
rel: RelTag,
blknum: BlockNumber,
img: Bytes,
@@ -900,7 +899,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn put_rel_wal_record(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
rel: RelTag,
blknum: BlockNumber,
rec: ZenithWalRecord,
@@ -912,7 +911,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn put_rel_truncation(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
rel: RelTag,
nblocks: BlockNumber,
) -> Result<()> {
@@ -923,7 +922,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn put_rel_drop(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
rel: RelTag,
) -> Result<()> {
modification.put_rel_drop(rel)?;
@@ -948,7 +947,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn handle_rel_extend(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
rel: RelTag,
blknum: BlockNumber,
) -> Result<()> {
@@ -986,7 +985,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn put_slru_page_image(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
kind: SlruKind,
segno: u32,
blknum: BlockNumber,
@@ -999,7 +998,7 @@ impl<'a, R: Repository> WalIngest<'a, R> {
fn handle_slru_extend(
&mut self,
modification: &mut DatadirModification<R>,
modification: &mut DatadirModification<T>,
kind: SlruKind,
segno: u32,
blknum: BlockNumber,
@@ -1052,6 +1051,7 @@ mod tests {
use super::*;
use crate::pgdatadir_mapping::create_test_timeline;
use crate::repository::repo_harness::*;
use crate::repository::Timeline;
use postgres_ffi::pg_constants;
/// Arbitrary relation tag, for testing.
@@ -1062,17 +1062,17 @@ mod tests {
forknum: 0,
};
fn assert_current_logical_size<R: Repository>(_timeline: &DatadirTimeline<R>, _lsn: Lsn) {
fn assert_current_logical_size<T: Timeline>(_timeline: &T, _lsn: Lsn) {
// TODO
}
static ZERO_CHECKPOINT: Bytes = Bytes::from_static(&[0u8; SIZEOF_CHECKPOINT]);
fn init_walingest_test<R: Repository>(tline: &DatadirTimeline<R>) -> Result<WalIngest<R>> {
let mut m = tline.begin_modification(Lsn(0x10));
fn init_walingest_test<T: DatadirTimeline>(tline: &T) -> Result<WalIngest<T>> {
let mut m = tline.begin_modification();
m.put_checkpoint(ZERO_CHECKPOINT.clone())?;
m.put_relmap_file(0, 111, Bytes::from(""))?; // dummy relmapper file
m.commit()?;
m.commit(Lsn(0x10))?;
let walingest = WalIngest::new(tline, Lsn(0x10))?;
Ok(walingest)
@@ -1082,23 +1082,23 @@ mod tests {
fn test_relsize() -> Result<()> {
let repo = RepoHarness::create("test_relsize")?.load();
let tline = create_test_timeline(repo, TIMELINE_ID)?;
let mut walingest = init_walingest_test(&tline)?;
let mut walingest = init_walingest_test(&*tline)?;
let mut m = tline.begin_modification(Lsn(0x20));
let mut m = tline.begin_modification();
walingest.put_rel_creation(&mut m, TESTREL_A)?;
walingest.put_rel_page_image(&mut m, TESTREL_A, 0, TEST_IMG("foo blk 0 at 2"))?;
m.commit()?;
let mut m = tline.begin_modification(Lsn(0x30));
m.commit(Lsn(0x20))?;
let mut m = tline.begin_modification();
walingest.put_rel_page_image(&mut m, TESTREL_A, 0, TEST_IMG("foo blk 0 at 3"))?;
m.commit()?;
let mut m = tline.begin_modification(Lsn(0x40));
m.commit(Lsn(0x30))?;
let mut m = tline.begin_modification();
walingest.put_rel_page_image(&mut m, TESTREL_A, 1, TEST_IMG("foo blk 1 at 4"))?;
m.commit()?;
let mut m = tline.begin_modification(Lsn(0x50));
m.commit(Lsn(0x40))?;
let mut m = tline.begin_modification();
walingest.put_rel_page_image(&mut m, TESTREL_A, 2, TEST_IMG("foo blk 2 at 5"))?;
m.commit()?;
m.commit(Lsn(0x50))?;
assert_current_logical_size(&tline, Lsn(0x50));
assert_current_logical_size(&*tline, Lsn(0x50));
// The relation was created at LSN 2, not visible at LSN 1 yet.
assert_eq!(tline.get_rel_exists(TESTREL_A, Lsn(0x10))?, false);
@@ -1142,10 +1142,10 @@ mod tests {
);
// Truncate last block
let mut m = tline.begin_modification(Lsn(0x60));
let mut m = tline.begin_modification();
walingest.put_rel_truncation(&mut m, TESTREL_A, 2)?;
m.commit()?;
assert_current_logical_size(&tline, Lsn(0x60));
m.commit(Lsn(0x60))?;
assert_current_logical_size(&*tline, Lsn(0x60));
// Check reported size and contents after truncation
assert_eq!(tline.get_rel_size(TESTREL_A, Lsn(0x60))?, 2);
@@ -1166,15 +1166,15 @@ mod tests {
);
// Truncate to zero length
let mut m = tline.begin_modification(Lsn(0x68));
let mut m = tline.begin_modification();
walingest.put_rel_truncation(&mut m, TESTREL_A, 0)?;
m.commit()?;
m.commit(Lsn(0x68))?;
assert_eq!(tline.get_rel_size(TESTREL_A, Lsn(0x68))?, 0);
// Extend from 0 to 2 blocks, leaving a gap
let mut m = tline.begin_modification(Lsn(0x70));
let mut m = tline.begin_modification();
walingest.put_rel_page_image(&mut m, TESTREL_A, 1, TEST_IMG("foo blk 1"))?;
m.commit()?;
m.commit(Lsn(0x70))?;
assert_eq!(tline.get_rel_size(TESTREL_A, Lsn(0x70))?, 2);
assert_eq!(
tline.get_rel_page_at_lsn(TESTREL_A, 0, Lsn(0x70))?,
@@ -1186,9 +1186,9 @@ mod tests {
);
// Extend a lot more, leaving a big gap that spans across segments
let mut m = tline.begin_modification(Lsn(0x80));
let mut m = tline.begin_modification();
walingest.put_rel_page_image(&mut m, TESTREL_A, 1500, TEST_IMG("foo blk 1500"))?;
m.commit()?;
m.commit(Lsn(0x80))?;
assert_eq!(tline.get_rel_size(TESTREL_A, Lsn(0x80))?, 1501);
for blk in 2..1500 {
assert_eq!(
@@ -1210,20 +1210,20 @@ mod tests {
fn test_drop_extend() -> Result<()> {
let repo = RepoHarness::create("test_drop_extend")?.load();
let tline = create_test_timeline(repo, TIMELINE_ID)?;
let mut walingest = init_walingest_test(&tline)?;
let mut walingest = init_walingest_test(&*tline)?;
let mut m = tline.begin_modification(Lsn(0x20));
let mut m = tline.begin_modification();
walingest.put_rel_page_image(&mut m, TESTREL_A, 0, TEST_IMG("foo blk 0 at 2"))?;
m.commit()?;
m.commit(Lsn(0x20))?;
// Check that rel exists and size is correct
assert_eq!(tline.get_rel_exists(TESTREL_A, Lsn(0x20))?, true);
assert_eq!(tline.get_rel_size(TESTREL_A, Lsn(0x20))?, 1);
// Drop rel
let mut m = tline.begin_modification(Lsn(0x30));
let mut m = tline.begin_modification();
walingest.put_rel_drop(&mut m, TESTREL_A)?;
m.commit()?;
m.commit(Lsn(0x30))?;
// Check that rel is not visible anymore
assert_eq!(tline.get_rel_exists(TESTREL_A, Lsn(0x30))?, false);
@@ -1232,9 +1232,9 @@ mod tests {
//assert!(tline.get_rel_size(TESTREL_A, Lsn(0x30))?.is_none());
// Re-create it
let mut m = tline.begin_modification(Lsn(0x40));
let mut m = tline.begin_modification();
walingest.put_rel_page_image(&mut m, TESTREL_A, 0, TEST_IMG("foo blk 0 at 4"))?;
m.commit()?;
m.commit(Lsn(0x40))?;
// Check that rel exists and size is correct
assert_eq!(tline.get_rel_exists(TESTREL_A, Lsn(0x40))?, true);
@@ -1250,16 +1250,16 @@ mod tests {
fn test_truncate_extend() -> Result<()> {
let repo = RepoHarness::create("test_truncate_extend")?.load();
let tline = create_test_timeline(repo, TIMELINE_ID)?;
let mut walingest = init_walingest_test(&tline)?;
let mut walingest = init_walingest_test(&*tline)?;
// Create a 20 MB relation (the size is arbitrary)
let relsize = 20 * 1024 * 1024 / 8192;
let mut m = tline.begin_modification(Lsn(0x20));
let mut m = tline.begin_modification();
for blkno in 0..relsize {
let data = format!("foo blk {} at {}", blkno, Lsn(0x20));
walingest.put_rel_page_image(&mut m, TESTREL_A, blkno, TEST_IMG(&data))?;
}
m.commit()?;
m.commit(Lsn(0x20))?;
// The relation was created at LSN 20, not visible at LSN 1 yet.
assert_eq!(tline.get_rel_exists(TESTREL_A, Lsn(0x10))?, false);
@@ -1280,9 +1280,9 @@ mod tests {
// Truncate relation so that second segment was dropped
// - only leave one page
let mut m = tline.begin_modification(Lsn(0x60));
let mut m = tline.begin_modification();
walingest.put_rel_truncation(&mut m, TESTREL_A, 1)?;
m.commit()?;
m.commit(Lsn(0x60))?;
// Check reported size and contents after truncation
assert_eq!(tline.get_rel_size(TESTREL_A, Lsn(0x60))?, 1);
@@ -1310,12 +1310,12 @@ mod tests {
// Extend relation again.
// Add enough blocks to create second segment
let lsn = Lsn(0x80);
let mut m = tline.begin_modification(lsn);
let mut m = tline.begin_modification();
for blkno in 0..relsize {
let data = format!("foo blk {} at {}", blkno, lsn);
walingest.put_rel_page_image(&mut m, TESTREL_A, blkno, TEST_IMG(&data))?;
}
m.commit()?;
m.commit(lsn)?;
assert_eq!(tline.get_rel_exists(TESTREL_A, Lsn(0x80))?, true);
assert_eq!(tline.get_rel_size(TESTREL_A, Lsn(0x80))?, relsize);
@@ -1338,18 +1338,18 @@ mod tests {
fn test_large_rel() -> Result<()> {
let repo = RepoHarness::create("test_large_rel")?.load();
let tline = create_test_timeline(repo, TIMELINE_ID)?;
let mut walingest = init_walingest_test(&tline)?;
let mut walingest = init_walingest_test(&*tline)?;
let mut lsn = 0x10;
for blknum in 0..pg_constants::RELSEG_SIZE + 1 {
lsn += 0x10;
let mut m = tline.begin_modification(Lsn(lsn));
let mut m = tline.begin_modification();
let img = TEST_IMG(&format!("foo blk {} at {}", blknum, Lsn(lsn)));
walingest.put_rel_page_image(&mut m, TESTREL_A, blknum as BlockNumber, img)?;
m.commit()?;
m.commit(Lsn(lsn))?;
}
assert_current_logical_size(&tline, Lsn(lsn));
assert_current_logical_size(&*tline, Lsn(lsn));
assert_eq!(
tline.get_rel_size(TESTREL_A, Lsn(lsn))?,
@@ -1358,34 +1358,34 @@ mod tests {
// Truncate one block
lsn += 0x10;
let mut m = tline.begin_modification(Lsn(lsn));
let mut m = tline.begin_modification();
walingest.put_rel_truncation(&mut m, TESTREL_A, pg_constants::RELSEG_SIZE)?;
m.commit()?;
m.commit(Lsn(lsn))?;
assert_eq!(
tline.get_rel_size(TESTREL_A, Lsn(lsn))?,
pg_constants::RELSEG_SIZE
);
assert_current_logical_size(&tline, Lsn(lsn));
assert_current_logical_size(&*tline, Lsn(lsn));
// Truncate another block
lsn += 0x10;
let mut m = tline.begin_modification(Lsn(lsn));
let mut m = tline.begin_modification();
walingest.put_rel_truncation(&mut m, TESTREL_A, pg_constants::RELSEG_SIZE - 1)?;
m.commit()?;
m.commit(Lsn(lsn))?;
assert_eq!(
tline.get_rel_size(TESTREL_A, Lsn(lsn))?,
pg_constants::RELSEG_SIZE - 1
);
assert_current_logical_size(&tline, Lsn(lsn));
assert_current_logical_size(&*tline, Lsn(lsn));
// Truncate to 1500, and then truncate all the way down to 0, one block at a time
// This tests the behavior at segment boundaries
let mut size: i32 = 3000;
while size >= 0 {
lsn += 0x10;
let mut m = tline.begin_modification(Lsn(lsn));
let mut m = tline.begin_modification();
walingest.put_rel_truncation(&mut m, TESTREL_A, size as BlockNumber)?;
m.commit()?;
m.commit(Lsn(lsn))?;
assert_eq!(
tline.get_rel_size(TESTREL_A, Lsn(lsn))?,
size as BlockNumber
@@ -1393,7 +1393,7 @@ mod tests {
size -= 1;
}
assert_current_logical_size(&tline, Lsn(lsn));
assert_current_logical_size(&*tline, Lsn(lsn));
Ok(())
}

View File

@@ -25,7 +25,8 @@ use etcd_broker::{
use tokio::select;
use tracing::*;
use crate::DatadirTimelineImpl;
use crate::repository::{Repository, Timeline};
use crate::{RepositoryImpl, TimelineImpl};
use utils::{
lsn::Lsn,
pq_proto::ReplicationFeedback,
@@ -39,7 +40,7 @@ pub(super) fn spawn_connection_manager_task(
id: ZTenantTimelineId,
broker_loop_prefix: String,
mut client: Client,
local_timeline: Arc<DatadirTimelineImpl>,
local_timeline: Arc<TimelineImpl>,
wal_connect_timeout: Duration,
lagging_wal_timeout: Duration,
max_lsn_wal_lag: NonZeroU64,
@@ -229,8 +230,8 @@ async fn subscribe_for_timeline_updates(
}
}
const DEFAULT_BASE_BACKOFF_SECONDS: f64 = 2.0;
const DEFAULT_MAX_BACKOFF_SECONDS: f64 = 60.0;
const DEFAULT_BASE_BACKOFF_SECONDS: f64 = 0.1;
const DEFAULT_MAX_BACKOFF_SECONDS: f64 = 3.0;
async fn exponential_backoff(n: u32, base: f64, max_seconds: f64) {
if n == 0 {
@@ -245,7 +246,7 @@ async fn exponential_backoff(n: u32, base: f64, max_seconds: f64) {
struct WalreceiverState {
id: ZTenantTimelineId,
/// Use pageserver data about the timeline to filter out some of the safekeepers.
local_timeline: Arc<DatadirTimelineImpl>,
local_timeline: Arc<TimelineImpl>,
/// The timeout on the connection to safekeeper for WAL streaming.
wal_connect_timeout: Duration,
/// The timeout to use to determine when the current connection is "stale" and reconnect to the other one.
@@ -283,7 +284,7 @@ struct EtcdSkTimeline {
impl WalreceiverState {
fn new(
id: ZTenantTimelineId,
local_timeline: Arc<DatadirTimelineImpl>,
local_timeline: Arc<<RepositoryImpl as Repository>::Timeline>,
wal_connect_timeout: Duration,
lagging_wal_timeout: Duration,
max_lsn_wal_lag: NonZeroU64,
@@ -1203,13 +1204,10 @@ mod tests {
tenant_id: harness.tenant_id,
timeline_id: TIMELINE_ID,
},
local_timeline: Arc::new(DatadirTimelineImpl::new(
harness
.load()
.create_empty_timeline(TIMELINE_ID, Lsn(0))
.expect("Failed to create an empty timeline for dummy wal connection manager"),
10_000,
)),
local_timeline: harness
.load()
.create_empty_timeline(TIMELINE_ID, Lsn(0))
.expect("Failed to create an empty timeline for dummy wal connection manager"),
wal_connect_timeout: Duration::from_secs(1),
lagging_wal_timeout: Duration::from_secs(1),
max_lsn_wal_lag: NonZeroU64::new(1).unwrap(),

View File

@@ -9,20 +9,22 @@ use std::{
use anyhow::{bail, ensure, Context};
use bytes::BytesMut;
use fail::fail_point;
use futures::StreamExt;
use postgres::{SimpleQueryMessage, SimpleQueryRow};
use postgres_protocol::message::backend::ReplicationMessage;
use postgres_types::PgLsn;
use tokio::{pin, select, sync::watch, time};
use tokio_postgres::{replication::ReplicationStream, Client};
use tokio_stream::StreamExt;
use tracing::{debug, error, info, info_span, trace, warn, Instrument};
use super::TaskEvent;
use crate::{
http::models::WalReceiverEntry,
pgdatadir_mapping::DatadirTimeline,
repository::{Repository, Timeline},
tenant_mgr,
walingest::WalIngest,
walrecord::DecodedWALRecord,
};
use postgres_ffi::waldecoder::WalStreamDecoder;
use utils::{lsn::Lsn, pq_proto::ReplicationFeedback, zid::ZTenantTimelineId};
@@ -150,19 +152,25 @@ pub async fn handle_walreceiver_connection(
waldecoder.feed_bytes(data);
while let Some((lsn, recdata)) = waldecoder.poll_decode()? {
let _enter = info_span!("processing record", lsn = %lsn).entered();
{
let mut decoded = DecodedWALRecord::default();
let mut modification = timeline.begin_modification();
while let Some((lsn, recdata)) = waldecoder.poll_decode()? {
// let _enter = info_span!("processing record", lsn = %lsn).entered();
// It is important to deal with the aligned records as lsn in getPage@LSN is
// aligned and can be several bytes bigger. Without this alignment we are
// at risk of hitting a deadlock.
ensure!(lsn.is_aligned());
// It is important to deal with the aligned records as lsn in getPage@LSN is
// aligned and can be several bytes bigger. Without this alignment we are
// at risk of hitting a deadlock.
ensure!(lsn.is_aligned());
walingest.ingest_record(&timeline, recdata, lsn)?;
walingest
.ingest_record(recdata, lsn, &mut modification, &mut decoded)
.context("could not ingest record at {lsn}")?;
fail_point!("walreceiver-after-ingest");
fail_point!("walreceiver-after-ingest");
last_rec_lsn = lsn;
last_rec_lsn = lsn;
}
}
if !caught_up && endlsn >= end_of_wal {
@@ -170,7 +178,7 @@ pub async fn handle_walreceiver_connection(
caught_up = true;
}
let timeline_to_check = Arc::clone(&timeline.tline);
let timeline_to_check = Arc::clone(&timeline);
tokio::task::spawn_blocking(move || timeline_to_check.check_checkpoint_distance())
.await
.with_context(|| {
@@ -218,7 +226,7 @@ pub async fn handle_walreceiver_connection(
// The last LSN we processed. It is not guaranteed to survive pageserver crash.
let write_lsn = u64::from(last_lsn);
// `disk_consistent_lsn` is the LSN at which page server guarantees local persistence of all received data
let flush_lsn = u64::from(timeline.tline.get_disk_consistent_lsn());
let flush_lsn = u64::from(timeline.get_disk_consistent_lsn());
// The last LSN that is synced to remote storage and is guaranteed to survive pageserver crash
// Used by safekeepers to remove WAL preceding `remote_consistent_lsn`.
let apply_lsn = u64::from(timeline_remote_consistent_lsn);

View File

@@ -96,6 +96,7 @@ impl DecodedBkpBlock {
}
}
#[derive(Default)]
pub struct DecodedWALRecord {
pub xl_xid: TransactionId,
pub xl_info: u8,
@@ -505,7 +506,17 @@ impl XlMultiXactTruncate {
// block data
// ...
// main data
pub fn decode_wal_record(record: Bytes) -> Result<DecodedWALRecord, DeserializeError> {
//
//
// For performance reasons, the caller provides the DecodedWALRecord struct and the function just fills it in.
// It would be more natural for this function to return a DecodedWALRecord as return value,
// but reusing the caller-supplied struct avoids an allocation.
// This code is in the hot path for digesting incoming WAL, and is very performance sensitive.
//
pub fn decode_wal_record(
record: Bytes,
decoded: &mut DecodedWALRecord,
) -> Result<(), DeserializeError> {
let mut rnode_spcnode: u32 = 0;
let mut rnode_dbnode: u32 = 0;
let mut rnode_relnode: u32 = 0;
@@ -534,7 +545,7 @@ pub fn decode_wal_record(record: Bytes) -> Result<DecodedWALRecord, DeserializeE
let mut blocks_total_len: u32 = 0;
let mut main_data_len = 0;
let mut datatotal: u32 = 0;
let mut blocks: Vec<DecodedBkpBlock> = Vec::new();
decoded.blocks.clear();
// 2. Decode the headers.
// XLogRecordBlockHeaders if any,
@@ -713,7 +724,7 @@ pub fn decode_wal_record(record: Bytes) -> Result<DecodedWALRecord, DeserializeE
blk.blkno
);
blocks.push(blk);
decoded.blocks.push(blk);
}
_ => {
@@ -724,7 +735,7 @@ pub fn decode_wal_record(record: Bytes) -> Result<DecodedWALRecord, DeserializeE
// 3. Decode blocks.
let mut ptr = record.len() - buf.remaining();
for blk in blocks.iter_mut() {
for blk in decoded.blocks.iter_mut() {
if blk.has_image {
blk.bimg_offset = ptr as u32;
ptr += blk.bimg_len as usize;
@@ -744,14 +755,13 @@ pub fn decode_wal_record(record: Bytes) -> Result<DecodedWALRecord, DeserializeE
assert_eq!(buf.remaining(), main_data_len as usize);
}
Ok(DecodedWALRecord {
xl_xid: xlogrec.xl_xid,
xl_info: xlogrec.xl_info,
xl_rmid: xlogrec.xl_rmid,
record,
blocks,
main_data_offset,
})
decoded.xl_xid = xlogrec.xl_xid;
decoded.xl_info = xlogrec.xl_info;
decoded.xl_rmid = xlogrec.xl_rmid;
decoded.record = record;
decoded.main_data_offset = main_data_offset;
Ok(())
}
///

View File

@@ -1,11 +1,14 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::DatabaseInfo;
pub use backend::{BackendType, DatabaseInfo};
mod credentials;
pub use credentials::ClientCredentials;
mod password_hack;
use password_hack::PasswordHackPayload;
mod flow;
pub use flow::*;
@@ -29,9 +32,8 @@ pub enum AuthErrorImpl {
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
/// For passwords that couldn't be processed by [`backend::legacy_console::parse_password`].
#[error("Malformed password message")]
MalformedPassword,
#[error("Malformed password message: {0}")]
MalformedPassword(&'static str),
/// Errors produced by [`crate::stream::PqStream`].
#[error(transparent)]
@@ -76,7 +78,7 @@ impl UserFacingError for AuthError {
Console(e) => e.to_string_client(),
GetAuthInfo(e) => e.to_string_client(),
Sasl(e) => e.to_string_client(),
MalformedPassword => self.to_string(),
MalformedPassword(_) => self.to_string(),
_ => "Internal error".to_string(),
}
}

View File

@@ -1,16 +1,14 @@
mod legacy_console;
mod link;
mod postgres;
pub mod console;
mod legacy_console;
pub use legacy_console::{AuthError, AuthErrorImpl};
use super::ClientCredentials;
use crate::{
compute,
config::{AuthBackendType, ProxyConfig},
mgmt,
auth::{self, AuthFlow, ClientCredentials},
compute, config, mgmt,
stream::PqStream,
waiters::{self, Waiter, Waiters},
};
@@ -78,32 +76,158 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
}
}
pub(super) async fn handle_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: ClientCredentials,
) -> super::Result<compute::NodeInfo> {
use AuthBackendType::*;
match config.auth_backend {
LegacyConsole => {
legacy_console::handle_user(
&config.auth_endpoint,
&config.auth_link_uri,
client,
&creds,
)
.await
/// Authentication backend selector which may carry a payload of type `T`.
///
/// The payload type decides how a value of this type is used:
///
/// * With `T = ()` it is just a regular auth backend selector,
///   the form used in [`crate::config::ProxyConfig`].
///
/// * With `T =` [`ClientCredentials`] it lets us hand the credentials
///   only to those auth backends which require them for the
///   authentication process.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BackendType<T> {
    /// Legacy Cloud API (V1) + link auth.
    LegacyConsole(T),
    /// Current Cloud API (V2).
    Console(T),
    /// Local mock of Cloud API (V2).
    Postgres(T),
    /// Authentication via a web browser.
    Link,
}

impl<T> BackendType<T> {
    /// Converts [`BackendType<T>`] into [`BackendType<R>`] by running `f`
    /// on the contained value, in the spirit of [`std::option::Option::map`].
    ///
    /// The variant itself never changes; [`BackendType::Link`] holds no
    /// value, so `f` is not called for it.
    pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<R> {
        match self {
            Self::LegacyConsole(payload) => BackendType::LegacyConsole(f(payload)),
            Self::Console(payload) => BackendType::Console(f(payload)),
            Self::Postgres(payload) => BackendType::Postgres(f(payload)),
            Self::Link => BackendType::Link,
        }
    }
}

impl<T, E> BackendType<Result<T, E>> {
    /// Pulls the [`Result`] out of the variant, in the spirit of
    /// [`std::option::Option::transpose`]. Handy for error handling.
    ///
    /// Returns `Err` with the contained error if the payload is an `Err`;
    /// otherwise returns `Ok` with the same variant holding the unwrapped
    /// value. [`BackendType::Link`] always yields `Ok`.
    pub fn transpose(self) -> Result<BackendType<T>, E> {
        let unwrapped = match self {
            Self::LegacyConsole(result) => BackendType::LegacyConsole(result?),
            Self::Console(result) => BackendType::Console(result?),
            Self::Postgres(result) => BackendType::Postgres(result?),
            Self::Link => BackendType::Link,
        };

        Ok(unwrapped)
    }
}
impl BackendType<ClientCredentials> {
/// Authenticate the client via the requested backend, possibly using credentials.
///
/// * `urls` supplies the auth endpoint and the link-auth URI handed to the backends.
/// * `client` is the stream to the client over which the auth exchange runs.
///
/// On success returns the [`compute::NodeInfo`] produced by the chosen backend.
/// Any error from the auth flow or from a backend call is propagated to the caller.
pub async fn authenticate(
mut self,
urls: &config::AuthUrls,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> super::Result<compute::NodeInfo> {
use BackendType::*;
// Special path, taken only by the `Console` and `Postgres` backends
// and only when the credentials carry no project yet.
if let Console(creds) | Postgres(creds) = &mut self {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the project name.
// We now expect to see a very specific payload in the place of password.
if creds.project().is_none() {
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
.await?
.authenticate()
.await?;
// Finally we may finish the initialization of `creds`.
// TODO: add missing type safety to ClientCredentials.
creds.project = Some(payload.project);
// `self` is matched again (by shared reference this time) to build
// the connection config through the API of the selected backend.
let mut config = match &self {
Console(creds) => {
console::Api::new(&urls.auth_endpoint, creds)
.wake_compute()
.await?
}
Postgres(creds) => {
postgres::Api::new(&urls.auth_endpoint, creds)
.wake_compute()
.await?
}
// The enclosing `if let` admits only `Console` and `Postgres`.
_ => unreachable!("see the patterns above"),
};
// We should use a password from payload as well.
config.password(payload.password);
// Return early: the regular per-backend `handle_user` below is skipped.
return Ok(compute::NodeInfo {
reported_auth_ok: false,
config,
});
}
}
// Regular path: delegate to the `handle_user` of the selected backend.
match self {
LegacyConsole(creds) => {
legacy_console::handle_user(
&urls.auth_endpoint,
&urls.auth_link_uri,
&creds,
client,
)
.await
}
Console(creds) => {
console::Api::new(&urls.auth_endpoint, &creds)
.handle_user(client)
.await
}
Postgres(creds) => {
postgres::Api::new(&urls.auth_endpoint, &creds)
.handle_user(client)
.await
}
// NOTE: this auth backend doesn't use client credentials.
Link => link::handle_user(&urls.auth_link_uri, client).await,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Mapping with the identity function must leave every variant unchanged.
#[test]
fn test_backend_type_map() {
    use BackendType::{Console, LegacyConsole, Link, Postgres};

    let identity = |payload| payload;
    for original in [LegacyConsole(0), Console(0), Postgres(0), Link] {
        let mapped = original.map(identity);
        assert_eq!(mapped, original);
    }
}
#[test]
fn test_backend_type_transpose() {
let values = [
BackendType::LegacyConsole(Ok::<_, ()>(0)),
BackendType::Console(Ok(0)),
BackendType::Postgres(Ok(0)),
BackendType::Link,
];
for value in values {
assert_eq!(value.map(Result::unwrap), value.transpose().unwrap());
}
Console => {
console::Api::new(&config.auth_endpoint, &creds)?
.handle_user(client)
.await
}
Postgres => {
postgres::Api::new(&config.auth_endpoint, &creds)?
.handle_user(client)
.await
}
Link => link::handle_user(&config.auth_link_uri, client).await,
}
}

View File

@@ -1,18 +1,17 @@
//! Cloud API V2.
use crate::{
auth::{self, AuthFlow, ClientCredentials, DatabaseInfo},
compute,
error::UserFacingError,
auth::{self, AuthFlow, ClientCredentials},
compute::{self, ComputeConnCfg},
error::{io_error, UserFacingError},
scram,
stream::PqStream,
url::ApiUrl,
};
use serde::{Deserialize, Serialize};
use std::{future::Future, io};
use std::future::Future;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
pub type Result<T> = std::result::Result<T, ConsoleAuthError>;
@@ -84,8 +83,8 @@ pub(super) struct Api<'a> {
impl<'a> Api<'a> {
/// Construct an API object containing the auth parameters.
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result<Self> {
Ok(Self { endpoint, creds })
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self {
Self { endpoint, creds }
}
/// Authenticate the existing user or throw an error.
@@ -100,7 +99,7 @@ impl<'a> Api<'a> {
let mut url = self.endpoint.clone();
url.path_segments_mut().push("proxy_get_role_secret");
url.query_pairs_mut()
.append_pair("project", self.creds.project_name.as_ref()?)
.append_pair("project", self.creds.project().expect("impossible"))
.append_pair("role", &self.creds.user);
// TODO: use a proper logger
@@ -120,11 +119,11 @@ impl<'a> Api<'a> {
}
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(&self) -> Result<DatabaseInfo> {
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg> {
let mut url = self.endpoint.clone();
url.path_segments_mut().push("proxy_wake_compute");
let project_name = self.creds.project_name.as_ref()?;
url.query_pairs_mut().append_pair("project", project_name);
url.query_pairs_mut()
.append_pair("project", self.creds.project().expect("impossible"));
// TODO: use a proper logger
println!("cplane request: {url}");
@@ -137,16 +136,20 @@ impl<'a> Api<'a> {
let response: GetWakeComputeResponse =
serde_json::from_str(&resp.text().await.map_err(io_error)?)?;
let (host, port) = parse_host_port(&response.address)
.ok_or(ConsoleAuthError::BadComputeAddress(response.address))?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&response.address) {
None => return Err(ConsoleAuthError::BadComputeAddress(response.address)),
Some(x) => x,
};
Ok(DatabaseInfo {
host,
port,
dbname: self.creds.dbname.to_owned(),
user: self.creds.user.to_owned(),
password: None,
})
let mut config = ComputeConnCfg::new();
config
.host(host)
.port(port)
.dbname(&self.creds.dbname)
.user(&self.creds.user);
Ok(config)
}
}
@@ -160,7 +163,7 @@ pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>(
) -> auth::Result<compute::NodeInfo>
where
GetAuthInfo: Future<Output = Result<AuthInfo>>,
WakeCompute: Future<Output = Result<DatabaseInfo>>,
WakeCompute: Future<Output = Result<ComputeConnCfg>>,
{
let auth_info = get_auth_info(endpoint).await?;
@@ -179,48 +182,18 @@ where
}
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
let mut config = wake_compute(endpoint).await?;
if let Some(keys) = scram_keys {
config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(keys));
}
Ok(compute::NodeInfo {
db_info: wake_compute(endpoint).await?,
scram_keys,
reported_auth_ok: false,
config,
})
}
/// Upcast (almost) any error into an opaque [`io::Error`].
pub(super) fn io_error(e: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
io::Error::new(io::ErrorKind::Other, e)
}
fn parse_host_port(input: &str) -> Option<(String, u16)> {
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
let (host, port) = input.split_once(':')?;
Some((host.to_owned(), port.parse().ok()?))
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_db_info() -> anyhow::Result<()> {
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"password": "password",
}))?;
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
}))?;
Ok(())
}
Some((host, port.parse().ok()?))
}

View File

@@ -11,7 +11,7 @@ use crate::{
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
use utils::pq_proto::BeMessage as Be;
#[derive(Debug, Error)]
pub enum AuthErrorImpl {
@@ -76,6 +76,12 @@ enum ProxyAuthResponse {
NotReady { ready: bool }, // TODO: get rid of `ready`
}
impl ClientCredentials {
    /// Reports whether these credentials belong to an existing user,
    /// i.e. whether the user name ends with the `@zenith` suffix.
    fn is_existing_user(&self) -> bool {
        const EXISTING_USER_SUFFIX: &str = "@zenith";

        self.user.ends_with(EXISTING_USER_SUFFIX)
    }
}
async fn authenticate_proxy_client(
auth_endpoint: &reqwest::Url,
creds: &ClientCredentials,
@@ -100,7 +106,7 @@ async fn authenticate_proxy_client(
}
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
println!("got auth info: #{:?}", auth_info);
println!("got auth info: {:?}", auth_info);
use ProxyAuthResponse::*;
let db_info = match auth_info {
@@ -128,7 +134,9 @@ async fn handle_existing_user(
// Read client's password hash
let msg = client.read_password_message().await?;
let md5_response = parse_password(&msg).ok_or(auth::AuthErrorImpl::MalformedPassword)?;
let md5_response = parse_password(&msg).ok_or(auth::AuthErrorImpl::MalformedPassword(
"the password should be a valid null-terminated utf-8 string",
))?;
let db_info = authenticate_proxy_client(
auth_endpoint,
@@ -139,21 +147,17 @@ async fn handle_existing_user(
)
.await?;
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(compute::NodeInfo {
db_info,
scram_keys: None,
reported_auth_ok: false,
config: db_info.into(),
})
}
pub async fn handle_user(
auth_endpoint: &reqwest::Url,
auth_link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: &ClientCredentials,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<compute::NodeInfo> {
if creds.is_existing_user() {
handle_existing_user(auth_endpoint, client, creds).await
@@ -201,4 +205,24 @@ mod tests {
.unwrap();
assert!(matches!(auth, ProxyAuthResponse::NotReady { .. }));
}
/// `DatabaseInfo` must deserialize from JSON both with and without
/// the `password` field.
#[test]
fn parse_db_info() -> anyhow::Result<()> {
    // All fields present, including the password.
    let with_password = json!({
        "host": "localhost",
        "port": 5432,
        "dbname": "postgres",
        "user": "john_doe",
        "password": "password",
    });
    serde_json::from_value::<DatabaseInfo>(with_password)?;

    // Same record, but the password is left out.
    let without_password = json!({
        "host": "localhost",
        "port": 5432,
        "dbname": "postgres",
        "user": "john_doe",
    });
    serde_json::from_value::<DatabaseInfo>(without_password)?;

    Ok(())
}
}

View File

@@ -41,7 +41,7 @@ pub async fn handle_user(
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
Ok(compute::NodeInfo {
db_info,
scram_keys: None,
reported_auth_ok: true,
config: db_info.into(),
})
}

View File

@@ -3,10 +3,12 @@
use crate::{
auth::{
self,
backend::console::{self, io_error, AuthInfo, Result},
ClientCredentials, DatabaseInfo,
backend::console::{self, AuthInfo, Result},
ClientCredentials,
},
compute, scram,
compute::{self, ComputeConnCfg},
error::io_error,
scram,
stream::PqStream,
url::ApiUrl,
};
@@ -20,8 +22,8 @@ pub(super) struct Api<'a> {
impl<'a> Api<'a> {
/// Construct an API object containing the auth parameters.
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result<Self> {
Ok(Self { endpoint, creds })
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self {
Self { endpoint, creds }
}
/// Authenticate the existing user or throw an error.
@@ -56,7 +58,10 @@ impl<'a> Api<'a> {
// We shouldn't get more than one row anyway.
[row, ..] => {
let entry = row.try_get(0).map_err(io_error)?;
let entry = row
.try_get("rolpassword")
.map_err(|e| io_error(format!("failed to read user's password: {e}")))?;
scram::ServerSecret::parse(entry)
.map(AuthInfo::Scram)
.or_else(|| {
@@ -75,14 +80,14 @@ impl<'a> Api<'a> {
}
/// We don't need to wake anything locally, so we just return the connection info.
async fn wake_compute(&self) -> Result<DatabaseInfo> {
Ok(DatabaseInfo {
// TODO: handle that near CLI params parsing
host: self.endpoint.host_str().unwrap_or("localhost").to_owned(),
port: self.endpoint.port().unwrap_or(5432),
dbname: self.creds.dbname.to_owned(),
user: self.creds.user.to_owned(),
password: None,
})
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg> {
let mut config = ComputeConnCfg::new();
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
.port(self.endpoint.port().unwrap_or(5432))
.dbname(&self.creds.dbname)
.user(&self.creds.user);
Ok(config)
}
}

View File

@@ -1,39 +1,25 @@
//! User credentials used in authentication.
use crate::compute;
use crate::config::ProxyConfig;
use crate::error::UserFacingError;
use crate::stream::PqStream;
use std::collections::HashMap;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::StartupMessageParams;
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum ClientCredsParseError {
#[error("Parameter `{0}` is missing in startup packet.")]
#[error("Parameter '{0}' is missing in startup packet.")]
MissingKey(&'static str),
#[error(
"Project name is not specified. \
EITHER please upgrade the postgres client library (libpq) for SNI support \
OR pass the project name as a parameter: '&options=project%3D<project-name>'."
)]
MissingSNIAndProjectName,
#[error("Inconsistent project name inferred from SNI ('{0}') and project option ('{1}').")]
InconsistentProjectNameAndSNI(String, String),
#[error("Common name is not set.")]
CommonNameNotSet,
InconsistentProjectNames(String, String),
#[error(
"SNI ('{1}') inconsistently formatted with respect to common name ('{0}'). \
SNI should be formatted as '<project-name>.<common-name>'."
SNI should be formatted as '<project-name>.{0}'."
)]
InconsistentCommonNameAndSNI(String, String),
InconsistentSni(String, String),
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphens ('-').")]
ProjectNameContainsIllegalChars(String),
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
MalformedProjectName(String),
}
impl UserFacingError for ClientCredsParseError {}
@@ -44,286 +30,171 @@ impl UserFacingError for ClientCredsParseError {}
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
pub project_name: Result<String, ClientCredsParseError>,
pub project: Option<String>,
}
impl ClientCredentials {
pub fn is_existing_user(&self) -> bool {
// This logic will likely change in the future.
self.user.ends_with("@zenith")
pub fn project(&self) -> Option<&str> {
self.project.as_deref()
}
}
impl ClientCredentials {
pub fn parse(
mut options: HashMap<String, String>,
sni_data: Option<&str>,
mut options: StartupMessageParams,
sni: Option<&str>,
common_name: Option<&str>,
) -> Result<Self, ClientCredsParseError> {
let mut get_param = |key| {
options
.remove(key)
.ok_or(ClientCredsParseError::MissingKey(key))
};
use ClientCredsParseError::*;
// Some parameters are absolutely necessary, others not so much.
let mut get_param = |key| options.remove(key).ok_or(MissingKey(key));
// Some parameters are stored in the startup message.
let user = get_param("user")?;
let dbname = get_param("database")?;
let project_name = get_param("project").ok();
let project_name = get_project_name(sni_data, common_name, project_name.as_deref());
let project_a = get_param("project").ok();
// Alternative project name is in fact a subdomain from SNI.
// NOTE: we do not consider SNI if `common_name` is missing.
let project_b = sni
.zip(common_name)
.map(|(sni, cn)| {
// TODO: what if SNI is present but just a common name?
subdomain_from_sni(sni, cn)
.ok_or_else(|| InconsistentSni(sni.to_owned(), cn.to_owned()))
})
.transpose()?;
let project = match (project_a, project_b) {
// Invariant: if we have both project name variants, they should match.
(Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a, b))),
(a, b) => a.or(b).map(|name| {
// Invariant: project name may not contain certain characters.
check_project_name(name).map_err(MalformedProjectName)
}),
}
.transpose()?;
Ok(Self {
user,
dbname,
project_name,
project,
})
}
}
/// Use credentials to authenticate the user.
pub async fn authenticate(
self,
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> super::Result<compute::NodeInfo> {
// This method is just a convenient facade for `handle_user`
super::backend::handle_user(config, client, self).await
/// Validate a project name, handing the string back on either path.
///
/// A name is accepted when it is non-empty and every character is
/// alphanumeric or a hyphen (`-`). On rejection the original name is
/// returned in `Err`, so the caller can embed it in an error value.
fn check_project_name(name: String) -> Result<String, String> {
    // `Iterator::all` is vacuously true for an empty string, so an empty name
    // (e.g. one derived from an SNI such as ".localhost") must be rejected explicitly.
    let is_valid = !name.is_empty() && name.chars().all(|c| c.is_alphanumeric() || c == '-');

    if is_valid {
        Ok(name)
    } else {
        Err(name)
    }
}
/// Inferring project name from sni_data.
fn project_name_from_sni_data(
sni_data: &str,
common_name: &str,
) -> Result<String, ClientCredsParseError> {
let common_name_with_dot = format!(".{common_name}");
// check that ".{common_name_with_dot}" is the actual suffix in sni_data
if !sni_data.ends_with(&common_name_with_dot) {
return Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
common_name.to_string(),
sni_data.to_string(),
/// Extract the subdomain part of `sni`, given that it has the form
/// `<subdomain>.<common_name>`.
///
/// Returns `None` when `sni` does not end with `common_name`, or when the
/// part preceding `common_name` does not end with a dot.
fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
    let before_common_name = sni.strip_suffix(common_name)?;
    let subdomain = before_common_name.strip_suffix('.')?;
    Some(subdomain.to_owned())
}
#[cfg(test)]
mod tests {
use super::*;
fn make_options<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> StartupMessageParams {
StartupMessageParams::from(pairs.map(|(k, v)| (k.to_owned(), v.to_owned())))
}
#[test]
#[ignore = "TODO: fix how database is handled"]
fn parse_bare_minimum() -> anyhow::Result<()> {
    // According to postgresql, only `user` should be required.
    // TODO: check that `creds.dbname` is None.
    let parsed = ClientCredentials::parse(make_options([("user", "john_doe")]), None, None)?;
    assert_eq!(parsed.user, "john_doe");

    Ok(())
}
#[test]
fn parse_missing_project() -> anyhow::Result<()> {
    // Neither SNI nor a `project` option is supplied, so no project is expected.
    let params = make_options([("user", "john_doe"), ("database", "world")]);
    let parsed = ClientCredentials::parse(params, None, None)?;

    assert_eq!(parsed.user, "john_doe");
    assert_eq!(parsed.dbname, "world");
    assert_eq!(parsed.project, None);

    Ok(())
}
#[test]
fn parse_project_from_sni() -> anyhow::Result<()> {
    // The project is the SNI subdomain in front of the certificate's common name.
    let params = make_options([("user", "john_doe"), ("database", "world")]);
    let parsed = ClientCredentials::parse(params, Some("foo.localhost"), Some("localhost"))?;

    assert_eq!(parsed.user, "john_doe");
    assert_eq!(parsed.dbname, "world");
    assert_eq!(parsed.project.as_deref(), Some("foo"));

    Ok(())
}
#[test]
fn parse_project_from_options() -> anyhow::Result<()> {
    // Without SNI, the project comes from the `project` startup parameter.
    let params = make_options([
        ("user", "john_doe"),
        ("database", "world"),
        ("project", "bar"),
    ]);
    let parsed = ClientCredentials::parse(params, None, None)?;

    assert_eq!(parsed.user, "john_doe");
    assert_eq!(parsed.dbname, "world");
    assert_eq!(parsed.project.as_deref(), Some("bar"));

    Ok(())
}
#[test]
fn parse_projects_identical() -> anyhow::Result<()> {
    // The same project name arrives both via SNI and via options: parsing succeeds.
    let params = make_options([
        ("user", "john_doe"),
        ("database", "world"),
        ("project", "baz"),
    ]);
    let parsed = ClientCredentials::parse(params, Some("baz.localhost"), Some("localhost"))?;

    assert_eq!(parsed.user, "john_doe");
    assert_eq!(parsed.dbname, "world");
    assert_eq!(parsed.project.as_deref(), Some("baz"));

    Ok(())
}
#[test]
fn parse_projects_different() {
    // The `project` option ("first") disagrees with the SNI subdomain ("second").
    let params = make_options([
        ("user", "john_doe"),
        ("database", "world"),
        ("project", "first"),
    ]);

    let outcome = ClientCredentials::parse(params, Some("second.localhost"), Some("localhost"));
    let err = outcome.expect_err("should fail");

    assert!(matches!(
        err,
        ClientCredsParseError::InconsistentProjectNames(_, _)
    ));
}
// return sni_data without the common name suffix.
Ok(sni_data
.strip_suffix(&common_name_with_dot)
.unwrap()
.to_string())
}
#[cfg(test)]
mod tests_for_project_name_from_sni_data {
use super::*;
#[test]
fn passing() {
let target_project_name = "my-project-123";
let common_name = "localtest.me";
let sni_data = format!("{target_project_name}.{common_name}");
assert_eq!(
project_name_from_sni_data(&sni_data, common_name),
Ok(target_project_name.to_string())
);
}
#[test]
fn throws_inconsistent_common_name_and_sni_data() {
let target_project_name = "my-project-123";
let common_name = "localtest.me";
let wrong_suffix = "wrongtest.me";
assert_eq!(common_name.len(), wrong_suffix.len());
let wrong_common_name = format!("wrong{wrong_suffix}");
let sni_data = format!("{target_project_name}.{wrong_common_name}");
assert_eq!(
project_name_from_sni_data(&sni_data, common_name),
Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
common_name.to_string(),
sni_data
))
);
}
}
/// Determine project name from SNI or from project_name parameter from options argument.
fn get_project_name(
sni_data: Option<&str>,
common_name: Option<&str>,
project_name: Option<&str>,
) -> Result<String, ClientCredsParseError> {
// determine the project name from sni_data if it exists, otherwise from project_name.
let ret = match sni_data {
Some(sni_data) => {
let common_name = common_name.ok_or(ClientCredsParseError::CommonNameNotSet)?;
let project_name_from_sni = project_name_from_sni_data(sni_data, common_name)?;
// check invariant: project name from options and from sni should match
if let Some(project_name) = &project_name {
if !project_name_from_sni.eq(project_name) {
return Err(ClientCredsParseError::InconsistentProjectNameAndSNI(
project_name_from_sni,
project_name.to_string(),
));
}
}
project_name_from_sni
}
None => project_name
.ok_or(ClientCredsParseError::MissingSNIAndProjectName)?
.to_string(),
};
// check formatting invariant: project name must contain only alphanumeric characters and hyphens.
if !ret.chars().all(|x: char| x.is_alphanumeric() || x == '-') {
return Err(ClientCredsParseError::ProjectNameContainsIllegalChars(ret));
}
Ok(ret)
}
#[cfg(test)]
mod tests_for_project_name_only {
use super::*;
#[test]
fn passing_from_sni_data_only() {
let target_project_name = "my-project-123";
let common_name = "localtest.me";
let sni_data = format!("{target_project_name}.{common_name}");
assert_eq!(
get_project_name(Some(&sni_data), Some(common_name), None),
Ok(target_project_name.to_string())
);
}
#[test]
fn throws_project_name_contains_illegal_chars_from_sni_data_only() {
let project_name_prefix = "my-project";
let project_name_suffix = "123";
let common_name = "localtest.me";
for illegal_char_id in 0..256 {
let illegal_char = char::from_u32(illegal_char_id).unwrap();
if !(illegal_char.is_alphanumeric() || illegal_char == '-')
&& illegal_char.to_string().len() == 1
{
let target_project_name =
format!("{project_name_prefix}{illegal_char}{project_name_suffix}");
let sni_data = format!("{target_project_name}.{common_name}");
assert_eq!(
get_project_name(Some(&sni_data), Some(common_name), None),
Err(ClientCredsParseError::ProjectNameContainsIllegalChars(
target_project_name
))
);
}
}
}
#[test]
fn passing_from_project_name_only() {
let target_project_name = "my-project-123";
let common_names = [Some("localtest.me"), None];
for common_name in common_names {
assert_eq!(
get_project_name(None, common_name, Some(target_project_name)),
Ok(target_project_name.to_string())
);
}
}
#[test]
fn throws_project_name_contains_illegal_chars_from_project_name_only() {
let project_name_prefix = "my-project";
let project_name_suffix = "123";
let common_names = [Some("localtest.me"), None];
for common_name in common_names {
for illegal_char_id in 0..256 {
let illegal_char: char = char::from_u32(illegal_char_id).unwrap();
if !(illegal_char.is_alphanumeric() || illegal_char == '-')
&& illegal_char.to_string().len() == 1
{
let target_project_name =
format!("{project_name_prefix}{illegal_char}{project_name_suffix}");
assert_eq!(
get_project_name(None, common_name, Some(&target_project_name)),
Err(ClientCredsParseError::ProjectNameContainsIllegalChars(
target_project_name
))
);
}
}
}
}
#[test]
fn passing_from_sni_data_and_project_name() {
let target_project_name = "my-project-123";
let common_name = "localtest.me";
let sni_data = format!("{target_project_name}.{common_name}");
assert_eq!(
get_project_name(
Some(&sni_data),
Some(common_name),
Some(target_project_name)
),
Ok(target_project_name.to_string())
);
}
#[test]
fn throws_inconsistent_project_name_and_sni() {
let project_name_param = "my-project-123";
let wrong_project_name = "not-my-project-123";
let common_name = "localtest.me";
let sni_data = format!("{wrong_project_name}.{common_name}");
assert_eq!(
get_project_name(Some(&sni_data), Some(common_name), Some(project_name_param)),
Err(ClientCredsParseError::InconsistentProjectNameAndSNI(
wrong_project_name.to_string(),
project_name_param.to_string()
))
);
}
#[test]
fn throws_common_name_not_set() {
let target_project_name = "my-project-123";
let wrong_project_name = "not-my-project-123";
let common_name = "localtest.me";
let sni_datas = [
Some(format!("{wrong_project_name}.{common_name}")),
Some(format!("{target_project_name}.{common_name}")),
];
let project_names = [None, Some(target_project_name)];
for sni_data in sni_datas {
for project_name_param in project_names {
assert_eq!(
get_project_name(sni_data.as_deref(), None, project_name_param),
Err(ClientCredsParseError::CommonNameNotSet)
);
}
}
}
#[test]
fn throws_inconsistent_common_name_and_sni_data() {
let target_project_name = "my-project-123";
let wrong_project_name = "not-my-project-123";
let common_name = "localtest.me";
let wrong_suffix = "wrongtest.me";
assert_eq!(common_name.len(), wrong_suffix.len());
let wrong_common_name = format!("wrong{wrong_suffix}");
let sni_datas = [
Some(format!("{wrong_project_name}.{wrong_common_name}")),
Some(format!("{target_project_name}.{wrong_common_name}")),
];
let project_names = [None, Some(target_project_name)];
for project_name_param in project_names {
for sni_data in &sni_datas {
assert_eq!(
get_project_name(sni_data.as_deref(), Some(common_name), project_name_param),
Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
common_name.to_string(),
sni_data.clone().unwrap().to_string()
))
);
}
}
}
}

View File

@@ -1,8 +1,7 @@
//! Main authentication flow.
use super::AuthErrorImpl;
use crate::stream::PqStream;
use crate::{sasl, scram};
use super::{AuthErrorImpl, PasswordHackPayload};
use crate::{sasl, scram, stream::PqStream};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
@@ -27,6 +26,17 @@ impl AuthMethod for Scram<'_> {
}
}
/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
pub struct PasswordHack;
impl AuthMethod for PasswordHack {
/// Opening backend message of this flow: `AuthenticationCleartextPassword`.
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub struct AuthFlow<'a, Stream, State> {
@@ -57,13 +67,34 @@ impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
    /// Read the client's password message and decode it into a [`PasswordHackPayload`].
    ///
    /// The message must be a NUL-terminated, base64-encoded JSON document;
    /// any deviation is reported as [`AuthErrorImpl::MalformedPassword`].
    pub async fn authenticate(self) -> super::Result<PasswordHackPayload> {
        let raw = self.stream.read_password_message().await?;

        // Drop the trailing NUL byte; its absence means the message is malformed.
        let encoded = raw
            .strip_suffix(&[0])
            .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;

        // The so-called "password" should contain a base64-encoded json.
        // It is used later to route the client to their project.
        let json = base64::decode(encoded)
            .map_err(|_| AuthErrorImpl::MalformedPassword("bad encoding"))?;

        let payload: PasswordHackPayload = serde_json::from_slice(&json)
            .map_err(|_| AuthErrorImpl::MalformedPassword("invalid payload"))?;

        Ok(payload)
    }
}
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<scram::ScramKey> {
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
let sasl = sasl::FirstMessage::parse(&msg)
.ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
// Currently, the only supported SASL method is SCRAM.
if !scram::METHODS.contains(&sasl.method) {

View File

@@ -0,0 +1,102 @@
//! Payload for ad hoc authentication method for clients that don't support SNI.
//! See the `impl` for [`super::backend::BackendType<ClientCredentials>`].
//! Read more: <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
use serde::{de, Deserialize, Deserializer};
use std::fmt;
/// Password carried inside [`PasswordHackPayload`], in one of two JSON spellings.
///
/// The enum is `#[serde(untagged)]`, so the variant is picked by which key is
/// present (`password` or `password_`) rather than by an explicit tag.
#[derive(Deserialize)]
#[serde(untagged)]
pub enum Password {
/// A regular string for utf-8 encoded passwords.
Simple { password: String },
/// Password is base64-encoded because it may contain arbitrary byte sequences.
Encoded {
// Read from the `password_` key and base64-decoded while deserializing.
#[serde(rename = "password_", deserialize_with = "deserialize_base64")]
password: Vec<u8>,
},
}
impl AsRef<[u8]> for Password {
fn as_ref(&self) -> &[u8] {
match self {
Password::Simple { password } => password.as_ref(),
Password::Encoded { password } => password.as_ref(),
}
}
}
/// Payload of the ad hoc authentication method for clients that don't support SNI.
#[derive(Deserialize)]
pub struct PasswordHackPayload {
/// Project name, read from the `project` key.
pub project: String,
/// The password itself; flattened, so its key (`password` or `password_`)
/// sits next to `project` in the same JSON object.
#[serde(flatten)]
pub password: Password,
}
fn deserialize_base64<'a, D: Deserializer<'a>>(des: D) -> Result<Vec<u8>, D::Error> {
// It's very tempting to replace this with
//
// ```
// let base64: &str = Deserialize::deserialize(des)?;
// base64::decode(base64).map_err(serde::de::Error::custom)
// ```
//
// Unfortunately, we can't always deserialize into `&str`, so we'd
// have to use an allocating `String` instead. Thus, visitor is better.
struct Visitor;
impl<'de> de::Visitor<'de> for Visitor {
type Value = Vec<u8>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
base64::decode(v).map_err(de::Error::custom)
}
}
des.deserialize_str(Visitor)
}
// Deserialization tests for `Password` and `PasswordHackPayload`.
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
use serde_json::json;
// Both spellings of the password key must yield the same raw bytes.
#[test]
fn parse_password() -> anyhow::Result<()> {
// Plain utf-8 password under the `password` key.
let password: Password = serde_json::from_value(json!({
"password": "foo",
}))?;
assert_eq!(password.as_ref(), "foo".as_bytes());
// Base64-encoded password under the `password_` key.
let password: Password = serde_json::from_value(json!({
"password_": base64::encode("foo"),
}))?;
assert_eq!(password.as_ref(), "foo".as_bytes());
Ok(())
}
// Full payload: each case pairs a password key with the encoder that produces its value.
#[rstest]
#[case("password", str::to_owned)]
#[case("password_", base64::encode)]
fn parse(#[case] key: &str, #[case] encode: fn(&'static str) -> String) -> anyhow::Result<()> {
let (password, project) = ("password", "pie-in-the-sky");
let payload = json!({
"project": project,
key: encode(password),
});
let payload: PasswordHackPayload = serde_json::from_value(payload)?;
assert_eq!(payload.password.as_ref(), password.as_bytes());
assert_eq!(payload.project, project);
Ok(())
}
}

View File

@@ -1,8 +1,6 @@
use crate::auth::DatabaseInfo;
use crate::cancellation::CancelClosure;
use crate::error::UserFacingError;
use std::io;
use std::net::SocketAddr;
use crate::{cancellation::CancelClosure, error::UserFacingError};
use futures::TryFutureExt;
use std::{io, net::SocketAddr};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::NoTls;
@@ -21,44 +19,96 @@ pub enum ConnectionError {
FailedToFetchPgVersion,
}
impl UserFacingError for ConnectionError {}
/// PostgreSQL version as [`String`].
pub type Version = String;
impl UserFacingError for ConnectionError {
fn to_string_client(&self) -> String {
use ConnectionError::*;
match self {
// This helps us drop irrelevant library-specific prefixes.
// TODO: propagate severity level and other parameters.
Postgres(err) => match err.as_db_error() {
Some(err) => err.message().to_string(),
None => err.to_string(),
},
other => other.to_string(),
}
}
}
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
/// Compute node connection params.
pub type ComputeConnCfg = tokio_postgres::Config;
/// Various compute node info for establishing connection etc.
pub struct NodeInfo {
pub db_info: DatabaseInfo,
pub scram_keys: Option<ScramKeys>,
/// Did we send [`utils::pq_proto::BeMessage::AuthenticationOk`]?
pub reported_auth_ok: bool,
/// Compute node connection params.
pub config: tokio_postgres::Config,
}
impl NodeInfo {
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> {
let host_port = (self.db_info.host.as_str(), self.db_info.port);
let socket = TcpStream::connect(host_port).await?;
let socket_addr = socket.peer_addr()?;
socket2::SockRef::from(&socket).set_keepalive(true)?;
use tokio_postgres::config::Host;
Ok((socket_addr, socket))
let connect_once = |host, port| {
TcpStream::connect((host, port)).and_then(|socket| async {
let socket_addr = socket.peer_addr()?;
// This prevents load balancer from severing the connection.
socket2::SockRef::from(&socket).set_keepalive(true)?;
Ok((socket_addr, socket))
})
};
// We can't reuse connection establishing logic from `tokio_postgres` here,
// because it has no means for extracting the underlying socket which we
// require for our business.
let mut connection_error = None;
let ports = self.config.get_ports();
for (i, host) in self.config.get_hosts().iter().enumerate() {
let port = ports.get(i).or_else(|| ports.get(0)).unwrap_or(&5432);
let host = match host {
Host::Tcp(host) => host.as_str(),
Host::Unix(_) => continue, // unix sockets are not welcome here
};
// TODO: maybe we should add a timeout.
match connect_once(host, *port).await {
Ok(socket) => return Ok(socket),
Err(err) => {
// We can't throw an error here, as there might be more hosts to try.
println!("failed to connect to compute `{host}:{port}`: {err}");
connection_error = Some(err);
}
}
}
Err(connection_error.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
format!("couldn't connect: bad compute config: {:?}", self.config),
)
}))
}
}
/// An established compute node connection: the socket plus the reported server version.
pub struct PostgresConnection {
/// Socket connected to a compute node.
pub stream: TcpStream,
/// PostgreSQL version of this instance.
pub version: String,
}
impl NodeInfo {
/// Connect to a corresponding compute node.
pub async fn connect(self) -> Result<(TcpStream, Version, CancelClosure), ConnectionError> {
let (socket_addr, mut socket) = self
pub async fn connect(&self) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
let (socket_addr, mut stream) = self
.connect_raw()
.await
.map_err(|_| ConnectionError::FailedToConnectToCompute)?;
let mut config = tokio_postgres::Config::from(self.db_info);
if let Some(scram_keys) = self.scram_keys {
config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(scram_keys));
}
// TODO: establish a secure connection to the DB
let (client, conn) = config.connect_raw(&mut socket, NoTls).await?;
let (client, conn) = self.config.connect_raw(&mut stream, NoTls).await?;
let version = conn
.parameter("server_version")
.ok_or(ConnectionError::FailedToFetchPgVersion)?
@@ -66,6 +116,8 @@ impl NodeInfo {
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
Ok((socket, version, cancel_closure))
let db = PostgresConnection { stream, version };
Ok((db, cancel_closure))
}
}

View File

@@ -1,28 +1,16 @@
use crate::url::ApiUrl;
use crate::{auth, url::ApiUrl};
use anyhow::{bail, ensure, Context};
use std::{str::FromStr, sync::Arc};
#[derive(Debug)]
pub enum AuthBackendType {
/// Legacy Cloud API (V1).
LegacyConsole,
/// Authentication via a web browser.
Link,
/// Current Cloud API (V2).
Console,
/// Local mock of Cloud API (V2).
Postgres,
}
impl FromStr for AuthBackendType {
impl FromStr for auth::BackendType<()> {
type Err = anyhow::Error;
fn from_str(s: &str) -> anyhow::Result<Self> {
use AuthBackendType::*;
use auth::BackendType::*;
Ok(match s {
"legacy" => LegacyConsole,
"console" => Console,
"postgres" => Postgres,
"legacy" => LegacyConsole(()),
"console" => Console(()),
"postgres" => Postgres(()),
"link" => Link,
_ => bail!("Invalid option `{s}` for auth method"),
})
@@ -31,7 +19,11 @@ impl FromStr for AuthBackendType {
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: AuthBackendType,
pub auth_backend: auth::BackendType<()>,
pub auth_urls: AuthUrls,
}
pub struct AuthUrls {
pub auth_endpoint: ApiUrl,
pub auth_link_uri: ApiUrl,
}
@@ -87,10 +79,8 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfi
"Failed to parse PEM object from bytes from file at '{cert_path}'."
))?
.1;
let almost_common_name = pem.parse_x509()?.tbs_certificate.subject.to_string();
let expected_prefix = "CN=*.";
let common_name = almost_common_name.strip_prefix(expected_prefix);
common_name.map(str::to_string)
let common_name = pem.parse_x509()?.subject().to_string();
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
};
Ok(TlsConfig {

View File

@@ -1,3 +1,5 @@
use std::io;
/// Marks errors that may be safely shown to a client.
/// This trait can be seen as a specialized version of [`ToString`].
///
@@ -15,3 +17,8 @@ pub trait UserFacingError: ToString {
self.to_string()
}
}
/// Wrap (almost) any error into an opaque [`io::Error`] of kind [`io::ErrorKind::Other`].
pub fn io_error(e: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
    let source: Box<dyn std::error::Error + Send + Sync> = e.into();
    io::Error::new(io::ErrorKind::Other, source)
}

View File

@@ -118,11 +118,15 @@ async fn main() -> anyhow::Result<()> {
let mgmt_address: SocketAddr = arg_matches.value_of("mgmt").unwrap().parse()?;
let http_address: SocketAddr = arg_matches.value_of("http").unwrap().parse()?;
let auth_urls = config::AuthUrls {
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
auth_link_uri: arg_matches.value_of("uri").unwrap().parse()?,
};
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
tls_config,
auth_backend: arg_matches.value_of("auth-backend").unwrap().parse()?,
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
auth_link_uri: arg_matches.value_of("uri").unwrap().parse()?,
auth_urls,
}));
println!("Version: {GIT_VERSION}");

View File

@@ -82,11 +82,22 @@ async fn handle_client(
}
let tls = config.tls_config.as_ref();
let (stream, creds) = match handshake(stream, tls, cancel_map).await? {
let (mut stream, params) = match handshake(stream, tls, cancel_map).await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
let creds = {
let sni = stream.get_ref().sni_hostname();
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.map(|_| auth::ClientCredentials::parse(params, sni, common_name))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds);
cancel_map
.with_session(|session| client.connect_to_db(config, session))
@@ -101,12 +112,10 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, auth::ClientCredentials)>> {
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let common_name = tls.and_then(|cfg| cfg.common_name.as_deref());
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
@@ -147,18 +156,7 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}
// Get SNI info when available
let sni_data = match stream.get_ref() {
Stream::Tls { tls } => tls.get_ref().1.sni_hostname().map(|s| s.to_owned()),
_ => None,
};
// Construct credentials
let creds =
auth::ClientCredentials::parse(params, sni_data.as_deref(), common_name);
let creds = async { creds }.or_else(|e| stream.throw_error(e)).await?;
break Ok(Some((stream, creds)));
break Ok(Some((stream, params)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;
@@ -174,12 +172,12 @@ struct Client<S> {
/// The underlying libpq protocol stream.
stream: PqStream<S>,
/// Client credentials that we care about.
creds: auth::ClientCredentials,
creds: auth::BackendType<auth::ClientCredentials>,
}
impl<S> Client<S> {
/// Construct a new connection context.
fn new(stream: PqStream<S>, creds: auth::ClientCredentials) -> Self {
fn new(stream: PqStream<S>, creds: auth::BackendType<auth::ClientCredentials>) -> Self {
Self { stream, creds }
}
}
@@ -194,16 +192,22 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
let Self { mut stream, creds } = self;
// Authenticate and connect to a compute node.
let auth = creds.authenticate(config, &mut stream).await;
let auth = creds.authenticate(&config.auth_urls, &mut stream).await;
let node = async { auth }.or_else(|e| stream.throw_error(e)).await?;
let (db, version, cancel_closure) =
node.connect().or_else(|e| stream.throw_error(e)).await?;
let (db, cancel_closure) = node.connect().or_else(|e| stream.throw_error(e)).await?;
let cancel_key_data = session.enable_cancellation(cancel_closure);
// Report authentication success if we haven't done this already.
if !node.reported_auth_ok {
stream
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
}
stream
.write_message_noflush(&BeMessage::ParameterStatus(
BeParameterStatusMessage::ServerVersion(&version),
BeParameterStatusMessage::ServerVersion(&db.version),
))?
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&BeMessage::ReadyForQuery)
@@ -217,7 +221,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
}
// Starting from here we only proxy the client's traffic.
let mut db = MetricsStream::new(db, inc_proxied);
let mut db = MetricsStream::new(db.stream, inc_proxied);
let mut client = MetricsStream::new(stream.into_inner(), inc_proxied);
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
@@ -279,9 +283,13 @@ mod tests {
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![cert], key)?;
.with_single_cert(vec![cert], key)?
.into();
config.into()
TlsConfig {
config,
common_name: Some(common_name.to_string()),
}
};
let client_config = {
@@ -297,11 +305,6 @@ mod tests {
ClientConfig { config, hostname }
};
let tls_config = TlsConfig {
config: tls_config,
common_name: Some(common_name.to_string()),
};
Ok((client_config, tls_config))
}
@@ -357,7 +360,7 @@ mod tests {
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
let (mut stream, _creds) = handshake(client, tls.as_ref(), &cancel_map)
let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map)
.await?
.context("handshake failed")?;
@@ -436,32 +439,6 @@ mod tests {
proxy.await?
}
#[tokio::test]
async fn give_user_an_error_for_bad_creds() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
let client_err = tokio_postgres::Config::new()
.ssl_mode(SslMode::Disable)
.connect_raw(server, NoTls)
.await
.err() // -> Option<E>
.context("client shouldn't be able to connect")?;
// TODO: this is ugly, but `format!` won't allow us to extract fmt string
assert!(client_err.to_string().contains("missing in startup packet"));
let server_err = proxy
.await?
.err() // -> Option<E>
.context("server shouldn't accept client")?;
assert!(client_err.to_string().contains(&server_err.to_string()));
Ok(())
}
#[tokio::test]
async fn keepalive_is_inherited() -> anyhow::Result<()> {
use tokio::net::{TcpListener, TcpStream};

View File

@@ -145,6 +145,14 @@ impl<S> Stream<S> {
pub fn from_raw(raw: S) -> Self {
Self::Raw { raw }
}
/// Return the SNI hostname when it's available.
///
/// A raw (non-TLS) stream never has one, so it always yields `None`.
pub fn sni_hostname(&self) -> Option<&str> {
    if let Stream::Tls { tls } = self {
        // The second element of the pair returned by `get_ref()` is the
        // one that exposes `sni_hostname()`.
        let (_, session) = tls.get_ref();
        session.sni_hostname()
    } else {
        None
    }
}
}
#[derive(Debug, Error)]

View File

@@ -20,7 +20,6 @@ postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="d052ee8
anyhow = "1.0"
crc32c = "0.6.0"
humantime = "2.1.0"
walkdir = "2"
url = "2.2.2"
signal-hook = "0.3.10"
serde = { version = "1.0", features = ["derive"] }
@@ -28,11 +27,9 @@ serde_with = "1.12.0"
hex = "0.4.3"
const_format = "0.2.21"
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
tokio-util = { version = "0.7", features = ["io"] }
git-version = "0.3.5"
async-trait = "0.1"
once_cell = "1.10.0"
futures = "0.3.13"
toml_edit = { version = "0.13", features = ["easy"] }
postgres_ffi = { path = "../libs/postgres_ffi" }

View File

@@ -1,9 +1,8 @@
use serde::{Deserialize, Serialize};
use utils::zid::{NodeId, ZTenantId, ZTimelineId};
use utils::zid::{NodeId, ZTimelineId};
#[derive(Serialize, Deserialize)]
pub struct TimelineCreateRequest {
pub tenant_id: ZTenantId,
pub timeline_id: ZTimelineId,
pub peer_ids: Vec<NodeId>,
}

View File

@@ -0,0 +1,365 @@
# OpenAPI description of the safekeeper HTTP control API.
# Keep the paths below in sync with the routes registered in the HTTP router.
openapi: "3.0.2"
info:
  title: Safekeeper control API
  version: "1.0"

servers:
  - url: "http://localhost:7676"

paths:
  /v1/status:
    get:
      tags:
        - "Info"
      summary: Get safekeeper status
      description: ""
      operationId: v1GetSafekeeperStatus
      responses:
        "200":
          description: Safekeeper status
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/SafekeeperStatus"
        "403":
          $ref: "#/components/responses/ForbiddenError"
        default:
          $ref: "#/components/responses/GenericError"

  /v1/tenant/{tenant_id}:
    parameters:
      - name: tenant_id
        in: path
        required: true
        schema:
          type: string
          format: hex
    delete:
      tags:
        - "Tenant"
      summary: Delete tenant and all its timelines
      description: "Deletes tenant and returns a map of timelines that were deleted along with their statuses"
      operationId: v1DeleteTenant
      responses:
        "200":
          description: Tenant deleted
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/TenantDeleteResult"
        "403":
          $ref: "#/components/responses/ForbiddenError"
        default:
          $ref: "#/components/responses/GenericError"

  /v1/tenant/{tenant_id}/timeline:
    parameters:
      - name: tenant_id
        in: path
        required: true
        schema:
          type: string
          format: hex
    post:
      tags:
        - "Timeline"
      summary: Register new timeline
      description: ""
      operationId: v1CreateTenantTimeline
      requestBody:
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/TimelineCreateRequest"
      responses:
        "201":
          description: Timeline created
          # TODO: return timeline info?
        "403":
          $ref: "#/components/responses/ForbiddenError"
        default:
          $ref: "#/components/responses/GenericError"

  /v1/tenant/{tenant_id}/timeline/{timeline_id}:
    parameters:
      - name: tenant_id
        in: path
        required: true
        schema:
          type: string
          format: hex
      - name: timeline_id
        in: path
        required: true
        schema:
          type: string
          format: hex
    get:
      tags:
        - "Timeline"
      summary: Get timeline information and status
      description: ""
      operationId: v1GetTenantTimeline
      responses:
        "200":
          description: Timeline status
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/TimelineStatus"
        "403":
          $ref: "#/components/responses/ForbiddenError"
        default:
          $ref: "#/components/responses/GenericError"
    delete:
      tags:
        - "Timeline"
      summary: Delete timeline
      description: ""
      operationId: v1DeleteTenantTimeline
      responses:
        "200":
          description: Timeline deleted
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/TimelineDeleteResult"
        "403":
          $ref: "#/components/responses/ForbiddenError"
        default:
          $ref: "#/components/responses/GenericError"

  /v1/record_safekeeper_info/{tenant_id}/{timeline_id}:
    parameters:
      - name: tenant_id
        in: path
        required: true
        schema:
          type: string
          format: hex
      - name: timeline_id
        in: path
        required: true
        schema:
          type: string
          format: hex
    post:
      tags:
        - "Tests"
      summary: Used only in tests to hand craft required data
      description: ""
      operationId: v1RecordSafekeeperInfo
      requestBody:
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/SkTimelineInfo"
      responses:
        "200":
          description: Timeline info posted
          # TODO: return timeline info?
        "403":
          $ref: "#/components/responses/ForbiddenError"
        default:
          $ref: "#/components/responses/GenericError"

components:
  securitySchemes:
    JWT:
      type: http
      scheme: bearer
      bearerFormat: JWT

  schemas:
    #
    # Requests
    #
    TimelineCreateRequest:
      type: object
      required:
        - timeline_id
        - peer_ids
      properties:
        timeline_id:
          type: string
          format: hex
        peer_ids:
          type: array
          items:
            type: integer
            minimum: 0

    SkTimelineInfo:
      type: object
      required:
        - last_log_term
        - flush_lsn
        - commit_lsn
        - backup_lsn
        - remote_consistent_lsn
        - peer_horizon_lsn
        - safekeeper_connstr
      properties:
        last_log_term:
          type: integer
          minimum: 0
        flush_lsn:
          type: string
        commit_lsn:
          type: string
        backup_lsn:
          type: string
        remote_consistent_lsn:
          type: string
        peer_horizon_lsn:
          type: string
        safekeeper_connstr:
          type: string

    #
    # Responses
    #
    SafekeeperStatus:
      type: object
      required:
        - id
      properties:
        id:
          type: integer
          minimum: 0  # kind of unsigned integer

    TimelineStatus:
      type: object
      required:
        - timeline_id
        - tenant_id
      properties:
        timeline_id:
          type: string
          format: hex
        tenant_id:
          type: string
          format: hex
        acceptor_state:
          $ref: '#/components/schemas/AcceptorStateStatus'
        flush_lsn:
          type: string
        timeline_start_lsn:
          type: string
        local_start_lsn:
          type: string
        commit_lsn:
          type: string
        backup_lsn:
          type: string
        peer_horizon_lsn:
          type: string
        remote_consistent_lsn:
          type: string

    AcceptorStateStatus:
      type: object
      required:
        - term
        - epoch
      properties:
        term:
          type: integer
          minimum: 0  # kind of unsigned integer
        epoch:
          type: integer
          minimum: 0  # kind of unsigned integer
        term_history:
          type: array
          items:
            $ref: '#/components/schemas/TermSwitchEntry'

    TermSwitchEntry:
      type: object
      required:
        - term
        - lsn
      properties:
        term:
          type: integer
          minimum: 0  # kind of unsigned integer
        lsn:
          type: string

    TimelineDeleteResult:
      type: object
      required:
        - dir_existed
        - was_active
      properties:
        dir_existed:
          type: boolean
        was_active:
          type: boolean

    TenantDeleteResult:
      type: object
      additionalProperties:
        $ref: "#/components/schemas/TimelineDeleteResult"
      # Keys are hex ids; they are quoted so that no parser can ever
      # mistake an id-looking key for a number.
      example:
        "57fd1b39f23704a63423de0a8435d85c":
          dir_existed: true
          was_active: false
        "67fd1b39f23704a63423ab8435d85c33":
          dir_existed: false
          was_active: false

    #
    # Errors
    #
    GenericErrorContent:
      type: object
      properties:
        msg:
          type: string

  responses:
    #
    # Errors
    #
    GenericError:
      description: Generic error response
      content:
        application/json:
          schema:
            $ref: "#/components/schemas/GenericErrorContent"

    ForbiddenError:
      description: Forbidden error response
      content:
        application/json:
          schema:
            type: object
            required:
              - msg
            properties:
              msg:
                type: string

security:
  - JWT: []

View File

@@ -126,7 +126,7 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
let request_data: TimelineCreateRequest = json_request(&mut request).await?;
let zttid = ZTenantTimelineId {
tenant_id: request_data.tenant_id,
tenant_id: parse_request_param(&request, "tenant_id")?,
timeline_id: request_data.timeline_id,
};
check_permission(&request, Some(zttid.tenant_id))?;
@@ -214,16 +214,19 @@ pub fn make_router(
}
}))
}
// NB: on any changes do not forget to update the OpenAPI spec
// located nearby (/safekeeper/src/http/openapi_spec.yaml).
router
.data(Arc::new(conf))
.data(auth)
.get("/v1/status", status_handler)
// Will be used in the future instead of implicit timeline creation
.post("/v1/tenant/:tenant_id/timeline", timeline_create_handler)
.get(
"/v1/timeline/:tenant_id/:timeline_id",
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_status_handler,
)
// Will be used in the future instead of implicit timeline creation
.post("/v1/timeline", timeline_create_handler)
.delete(
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_delete_force_handler,

View File

@@ -1,3 +1,6 @@
import threading
import pytest
import time
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnv
from fixtures.utils import lsn_from_hex
@@ -99,3 +102,72 @@ def test_branch_and_gc(neon_simple_env: NeonEnv):
branch_cur.execute('SELECT count(*) FROM foo')
assert branch_cur.fetchone() == (200000, )
# This test simulates a race condition happening when branch creation and GC are performed concurrently.
#
# Suppose we want to create a new timeline 't' from a source timeline 's' starting
# from a lsn 'lsn'. Upon creating 't', if we don't hold the GC lock and compare 'lsn' with
# the latest GC information carefully, it's possible for GC to accidentally remove data
# needed by the new timeline.
#
# In this test, GC is requested before the branch creation but is delayed to happen after branch creation.
# As a result, when doing GC for the source timeline, we don't have any information about
# the upcoming new branches, so it's possible to remove data that may be needed by the new branches.
# It's the branch creation task's job to make sure the starting 'lsn' is not out of scope
# and prevent creating branches with invalid starting LSNs.
#
# For more details, see discussion in https://github.com/neondatabase/neon/pull/2101#issuecomment-1185273447.
def test_branch_creation_before_gc(neon_simple_env: NeonEnv):
    """Check that branch creation rejects a start LSN that a queued GC will remove.

    GC is requested first but delayed by a failpoint, so the branch creation
    request is processed while the GC iteration is still pending. Creating a
    branch at the recorded LSN is expected to fail with
    "invalid branch start lsn".
    """
    env = neon_simple_env
    # Disable background GC but set the `pitr_interval` to be small, so GC can delete something
    tenant, _ = env.neon_cli.create_tenant(
        conf={
            # disable background GC
            'gc_period': '10 m',
            'gc_horizon': f'{10 * 1024 ** 3}',

            # small checkpoint distance to create more delta layer files
            'checkpoint_distance': f'{1024 ** 2}',

            # set the target size to be large to allow the image layer to cover the whole key space
            'compaction_target_size': f'{1024 ** 3}',

            # tweak the default settings to allow quickly create image layers and L1 layers
            'compaction_period': '1 s',
            'compaction_threshold': '2',
            'image_creation_threshold': '1',

            # set PITR interval to be small, so we can do GC
            'pitr_interval': '0 s'
        })

    b0 = env.neon_cli.create_branch('b0', tenant_id=tenant)
    pg0 = env.postgres.create_start('b0', tenant_id=tenant)
    res = pg0.safe_psql_many(queries=[
        "CREATE TABLE t(key serial primary key)",
        "INSERT INTO t SELECT FROM generate_series(1, 100000)",
        "SELECT pg_current_wal_insert_lsn()",
        "INSERT INTO t SELECT FROM generate_series(1, 100000)",
    ])
    # Result of the third query: the WAL insert LSN between the two inserts.
    lsn = res[2][0][0]

    # Use `failpoint=sleep` and `threading` to make the GC iteration triggers *before* the
    # branch creation task but the individual timeline GC iteration happens *after*
    # the branch creation task.
    env.pageserver.safe_psql("failpoints before-timeline-gc=sleep(2000)")

    def do_gc():
        env.pageserver.safe_psql(f"do_gc {tenant.hex} {b0.hex} 0")

    thread = threading.Thread(target=do_gc, daemon=True)
    thread.start()

    try:
        # because of network latency and other factors, GC iteration might be processed
        # after the `create_branch` request. Add a sleep here to make sure that GC is
        # always processed before.
        time.sleep(1.0)

        # The starting LSN is invalid as the corresponding record is scheduled to be removed by in-queue GC.
        with pytest.raises(Exception, match="invalid branch start lsn"):
            env.neon_cli.create_branch('b1', 'b0', tenant_id=tenant, ancestor_start_lsn=lsn)
    finally:
        # Wait for the delayed GC request to finish so it does not outlive
        # the test and overlap with environment teardown.
        thread.join()

View File

@@ -0,0 +1,82 @@
import time
import os
from fixtures.neon_fixtures import NeonEnvBuilder
from fixtures.log_helper import log
# This test creates a large number of tables, which results in a large catalog.
# Right now Neon serializes a directory as a single key-value storage entry, so
# this leads to layers filled mostly by one key.
# Originally, Neon's implementation of checkpoint and compaction was not able to split a key, which led
# to large (several gigabytes) layer files (both ephemeral and delta layers).
# This may cause problems with uploading to S3 and also degrade performance because of ephemeral file swapping.
#
def test_large_schema(neon_env_builder: NeonEnvBuilder):
    """Create many partitions across compute restarts and check layer file sizes.

    For each table the compute node is restarted, then the partitioned table
    and its partitions are (re)created with retries. Finally the row counts
    are verified and every layer file whose name starts with '00000' is
    asserted to be smaller than 512_000_000 bytes.
    """
    env = neon_env_builder.init_start()

    pg = env.postgres.create_start('main')

    tables = 2  # 10 is too much for debug build
    partitions = 1000
    for i in range(1, tables + 1):
        log.info(f'iteration {i} / {tables}')

        # Restart compute. Restart is actually not strictly needed.
        # It is done mostly because this test originally tries to model the problem reported by Ketteq.
        pg.stop()
        # Kill and restart the pageserver.
        # env.pageserver.stop(immediate=True)
        # env.pageserver.start()
        pg.start()

        retry_sleep = 0.5
        max_retries = 200
        retries = 0
        while True:
            try:
                conn = pg.connect()
                try:
                    cur = conn.cursor()
                    cur.execute(f"CREATE TABLE if not exists t_{i}(pk integer) partition by range (pk)")
                    for j in range(1, partitions + 1):
                        cur.execute(
                            f"create table if not exists p_{i}_{j} partition of t_{i} for values from ({j}) to ({j + 1})"
                        )
                    cur.execute(f"insert into t_{i} values (generate_series(1,{partitions}))")
                    cur.execute("vacuum full")
                finally:
                    # Close on failure as well, so failed attempts do not
                    # accumulate open connections across retries.
                    conn.close()

            except Exception as error:
                # It's normal that it takes some time for the pageserver to
                # restart, and for the connection to fail until it does. It
                # should eventually recover, so retry until it succeeds.
                log.info(f'failed: {error}')
                if retries >= max_retries:
                    raise
                retries += 1
                log.info(f'retry {retries} / {max_retries}')
                time.sleep(retry_sleep)
                continue
            break

    conn = pg.connect()
    cur = conn.cursor()

    for i in range(1, tables + 1):
        cur.execute(f"SELECT count(*) FROM t_{i}")
        assert cur.fetchone() == (partitions, )

    cur.execute("set enable_sort=off")
    cur.execute("select * from pg_depend order by refclassid, refobjid, refobjsubid")

    # Check layer file sizes
    tenant_id = pg.safe_psql("show neon.tenant_id")[0][0]
    timeline_id = pg.safe_psql("show neon.timeline_id")[0][0]
    timeline_path = "{}/tenants/{}/timelines/{}/".format(env.repo_dir, tenant_id, timeline_id)
    for filename in os.listdir(timeline_path):
        if filename.startswith('00000'):
            # Read the size once so the logged and the asserted value agree.
            size = os.path.getsize(os.path.join(timeline_path, filename))
            log.info(f'layer {filename} size is {size}')
            assert size < 512_000_000

View File

@@ -1,8 +1,34 @@
import pytest
import json
import base64
def test_proxy_select_1(static_proxy):
static_proxy.safe_psql("select 1;", options="project=generic-project-name")
static_proxy.safe_psql('select 1', options='project=generic-project-name')
def test_password_hack(static_proxy):
user = 'borat'
password = 'password'
static_proxy.safe_psql(f"create role {user} with login password '{password}'",
options='project=irrelevant')
def encode(s: str) -> str:
return base64.b64encode(s.encode('utf-8')).decode('utf-8')
magic = encode(json.dumps({
'project': 'irrelevant',
'password': password,
}))
static_proxy.safe_psql('select 1', sslsni=0, user=user, password=magic)
magic = encode(json.dumps({
'project': 'irrelevant',
'password_': encode(password),
}))
static_proxy.safe_psql('select 1', sslsni=0, user=user, password=magic)
# Pass extra options to the server.
@@ -11,8 +37,8 @@ def test_proxy_select_1(static_proxy):
# See https://github.com/neondatabase/neon/issues/1287
@pytest.mark.xfail
def test_proxy_options(static_proxy):
with static_proxy.connect(options="-cproxytest.option=value") as conn:
with static_proxy.connect(options='-cproxytest.option=value') as conn:
with conn.cursor() as cur:
cur.execute("SHOW proxytest.option;")
cur.execute('SHOW proxytest.option')
value = cur.fetchall()[0][0]
assert value == 'value'

View File

@@ -26,7 +26,7 @@ from fixtures.neon_fixtures import (
wait_for_upload,
wait_until,
)
from fixtures.utils import lsn_from_hex, subprocess_capture
from fixtures.utils import lsn_from_hex, lsn_to_hex, subprocess_capture
def assert_abs_margin_ratio(a: float, b: float, margin_ratio: float):
@@ -268,6 +268,7 @@ def test_tenant_relocation(neon_env_builder: NeonEnvBuilder,
env.neon_cli.create_branch(
new_branch_name="test_tenant_relocation_second",
ancestor_branch_name="test_tenant_relocation_main",
ancestor_start_lsn=lsn_to_hex(current_lsn_main),
tenant_id=tenant_id,
)
pg_second = env.postgres.create_start(branch_name='test_tenant_relocation_second',

View File

@@ -1,10 +1,15 @@
from contextlib import closing
import pathlib
from uuid import UUID
import re
import psycopg2.extras
import psycopg2.errors
from fixtures.neon_fixtures import NeonEnv, NeonEnvBuilder, Postgres, assert_timeline_local
from fixtures.log_helper import log
import time
from fixtures.utils import get_timeline_dir_size
def test_timeline_size(neon_simple_env: NeonEnv):
env = neon_simple_env
@@ -176,3 +181,129 @@ def test_timeline_size_quota(neon_env_builder: NeonEnvBuilder):
cur.execute("SELECT * from pg_size_pretty(pg_cluster_size())")
pg_cluster_size = cur.fetchone()
log.info(f"pg_cluster_size = {pg_cluster_size}")
def test_timeline_physical_size_init(neon_simple_env: NeonEnv):
    """Check the reported physical size after a pageserver restart."""
    branch = 'test_timeline_physical_size_init'
    env = neon_simple_env
    timeline = env.neon_cli.create_branch(branch)
    compute = env.postgres.create_start(branch)

    fill_table = ("INSERT INTO foo\n"
                  "SELECT 'long string to consume some space' || g\n"
                  "FROM generate_series(1, 1000) g")
    compute.safe_psql_many(["CREATE TABLE foo (t text)", fill_table])

    # restart the pageserver to force calculating timeline's initial physical size
    env.pageserver.stop()
    env.pageserver.start()

    assert_physical_size(env, env.initial_tenant, timeline)
def test_timeline_physical_size_post_checkpoint(neon_simple_env: NeonEnv):
    """Check the reported physical size right after an explicit checkpoint."""
    branch = 'test_timeline_physical_size_post_checkpoint'
    env = neon_simple_env
    timeline = env.neon_cli.create_branch(branch)
    compute = env.postgres.create_start(branch)

    fill_table = ("INSERT INTO foo\n"
                  "SELECT 'long string to consume some space' || g\n"
                  "FROM generate_series(1, 1000) g")
    compute.safe_psql_many(["CREATE TABLE foo (t text)", fill_table])

    target = f"{env.initial_tenant.hex} {timeline.hex}"
    env.pageserver.safe_psql(f"checkpoint {target}")
    assert_physical_size(env, env.initial_tenant, timeline)
def test_timeline_physical_size_post_compaction(neon_env_builder: NeonEnvBuilder):
    """Check the reported physical size after a checkpoint followed by compaction."""
    # Background compaction is disabled: if it ran between the size request and
    # the on-disk check, the two values could differ and the assertion would fail.
    neon_env_builder.pageserver_config_override = "tenant_config={checkpoint_distance=100000, compaction_period='10m'}"

    branch = 'test_timeline_physical_size_post_compaction'
    env = neon_env_builder.init_start()
    timeline = env.neon_cli.create_branch(branch)
    compute = env.postgres.create_start(branch)

    fill_table = ("INSERT INTO foo\n"
                  "SELECT 'long string to consume some space' || g\n"
                  "FROM generate_series(1, 100000) g")
    compute.safe_psql_many(["CREATE TABLE foo (t text)", fill_table])

    target = f"{env.initial_tenant.hex} {timeline.hex}"
    for command in ("checkpoint", "compact"):
        env.pageserver.safe_psql(f"{command} {target}")
    assert_physical_size(env, env.initial_tenant, timeline)
def test_timeline_physical_size_post_gc(neon_env_builder: NeonEnvBuilder):
    """Check the reported physical size after two checkpoints and a GC run."""
    # Background compaction and GC are disabled: if either ran between the size
    # request and the on-disk check, the assertion would fail.
    neon_env_builder.pageserver_config_override = (
        "tenant_config={checkpoint_distance=100000, compaction_period='10m', gc_period='10m', pitr_interval='1s'}"
    )

    branch = 'test_timeline_physical_size_post_gc'
    env = neon_env_builder.init_start()
    timeline = env.neon_cli.create_branch(branch)
    compute = env.postgres.create_start(branch)
    target = f"{env.initial_tenant.hex} {timeline.hex}"

    first_fill = ("INSERT INTO foo\n"
                  "SELECT 'long string to consume some space' || g\n"
                  "FROM generate_series(1, 100000) g")
    compute.safe_psql_many(["CREATE TABLE foo (t text)", first_fill])
    env.pageserver.safe_psql(f"checkpoint {target}")

    second_fill = ("\n"
                   "INSERT INTO foo\n"
                   "SELECT 'long string to consume some space' || g\n"
                   "FROM generate_series(1, 100000) g\n")
    compute.safe_psql(second_fill)
    env.pageserver.safe_psql(f"checkpoint {target}")

    env.pageserver.safe_psql(f"do_gc {target} 0")
    assert_physical_size(env, env.initial_tenant, timeline)
def test_timeline_physical_size_metric(neon_simple_env: NeonEnv):
    """Check that the physical size metric equals the layer file size on disk."""
    branch = 'test_timeline_physical_size_metric'
    env = neon_simple_env
    timeline = env.neon_cli.create_branch(branch)
    compute = env.postgres.create_start(branch)

    fill_table = ("INSERT INTO foo\n"
                  "SELECT 'long string to consume some space' || g\n"
                  "FROM generate_series(1, 100000) g")
    compute.safe_psql_many(["CREATE TABLE foo (t text)", fill_table])

    env.pageserver.safe_psql(f"checkpoint {env.initial_tenant.hex} {timeline.hex}")

    # get the metrics and parse the metric for the current timeline's physical size
    metrics_text = env.pageserver.http_client().get_metrics()
    pattern = (f'^pageserver_current_physical_size{{tenant_id="{env.initial_tenant.hex}",'
               f'timeline_id="{timeline.hex}"}} (\\S+)$')
    found = re.search(pattern, metrics_text, re.MULTILINE)
    assert found

    # assert that the metric matches the actual physical size on disk
    size_from_metric = int(found.group(1))
    on_disk_dir = env.timeline_dir(env.initial_tenant, timeline)
    assert size_from_metric == get_timeline_dir_size(on_disk_dir)
def assert_physical_size(env: NeonEnv, tenant_id: UUID, timeline_id: UUID):
    """Assert that the physical size reported by the timeline API is consistent.

    The incremental value must equal both the non-incremental value from the
    same response and the total layer file size of the timeline directory.
    """
    http = env.pageserver.http_client()
    detail = assert_timeline_local(http, tenant_id, timeline_id)
    on_disk_dir = env.timeline_dir(tenant_id, timeline_id)

    reported = detail["local"]["current_physical_size"]
    recomputed = detail["local"]["current_physical_size_non_incremental"]
    assert reported == recomputed
    assert detail["local"]["current_physical_size"] == get_timeline_dir_size(on_disk_dir)

View File

@@ -203,61 +203,6 @@ def test_restarts(neon_env_builder: NeonEnvBuilder):
assert cur.fetchone() == (500500, )
start_delay_sec = 2
def delayed_safekeeper_start(wa):
time.sleep(start_delay_sec)
wa.start()
# When majority of acceptors is offline, commits are expected to be frozen
def test_unavailability(neon_env_builder: NeonEnvBuilder):
neon_env_builder.num_safekeepers = 2
env = neon_env_builder.init_start()
env.neon_cli.create_branch('test_safekeepers_unavailability')
pg = env.postgres.create_start('test_safekeepers_unavailability')
# we rely upon autocommit after each statement
# as waiting for acceptors happens there
pg_conn = pg.connect()
cur = pg_conn.cursor()
# check basic work with table
cur.execute('CREATE TABLE t(key int primary key, value text)')
cur.execute("INSERT INTO t values (1, 'payload')")
# shutdown one of two acceptors, that is, majority
env.safekeepers[0].stop()
proc = Process(target=delayed_safekeeper_start, args=(env.safekeepers[0], ))
proc.start()
start = time.time()
cur.execute("INSERT INTO t values (2, 'payload')")
# ensure that the query above was hanging while acceptor was down
assert (time.time() - start) >= start_delay_sec
proc.join()
# for the world's balance, do the same with second acceptor
env.safekeepers[1].stop()
proc = Process(target=delayed_safekeeper_start, args=(env.safekeepers[1], ))
proc.start()
start = time.time()
cur.execute("INSERT INTO t values (3, 'payload')")
# ensure that the query above was hanging while acceptor was down
assert (time.time() - start) >= start_delay_sec
proc.join()
cur.execute("INSERT INTO t values (4, 'payload')")
cur.execute('SELECT sum(key) FROM t')
assert cur.fetchone() == (10, )
# shut down random subset of acceptors, sleep, wake them up, rinse, repeat
def xmas_garland(acceptors, stop):
while not bool(stop.value):

View File

@@ -146,9 +146,8 @@ async def run_restarts_under_load(env: NeonEnv,
max_transfer=100,
period_time=4,
iterations=10):
# Set timeout for this test at 5 minutes. It should be enough for test to complete
# and less than CircleCI's no_output_timeout, taking into account that this timeout
# is checked only at the beginning of every iteration.
# Set timeout for this test at 5 minutes. It should be enough for test to complete,
# taking into account that this timeout is checked only at the beginning of every iteration.
test_timeout_at = time.monotonic() + 5 * 60
pg_conn = await pg.connect_async()
@@ -404,3 +403,55 @@ def test_concurrent_computes(neon_env_builder: NeonEnvBuilder):
env.neon_cli.create_branch('test_concurrent_computes')
asyncio.run(run_concurrent_computes(env))
# Stop safekeeper and check that query cannot be executed while safekeeper is down.
# Query will insert a single row into a table.
async def check_unavailability(sk: Safekeeper,
                               conn: asyncpg.Connection,
                               key: int,
                               start_delay_sec: int = 2):
    """Stop `sk` and check that an insert does not finish until it is started again.

    The insert of a single row with the given `key` is launched as a background
    task; after `start_delay_sec` seconds it must still be pending. The
    safekeeper is then started and the insert is awaited.
    """
    # shutdown one of two acceptors, that is, majority
    sk.stop()

    insert = f"INSERT INTO t values ({key}, 'payload')"
    pending = asyncio.create_task(conn.execute(insert))

    await asyncio.sleep(start_delay_sec)
    # ensure that the query has not been executed yet
    assert not pending.done()

    # start safekeeper and await the query
    sk.start()
    await pending
    assert pending.done()
async def run_unavailability(env: NeonEnv, pg: Postgres):
    """Run the unavailability check against each of the first two safekeepers in turn."""
    conn = await pg.connect_async()

    # check basic work with table
    await conn.execute('CREATE TABLE t(key int primary key, value text)')
    await conn.execute("INSERT INTO t values (1, 'payload')")

    # stop a safekeeper and check that the query cannot be executed while it is
    # down; for the world's balance, do the same with the second safekeeper
    for sk_index, key in ((0, 2), (1, 3)):
        await check_unavailability(env.safekeepers[sk_index], conn, key)

    # check that we can execute queries after restart
    await conn.execute("INSERT INTO t values (4, 'payload')")

    total = await conn.fetchval('SELECT sum(key) FROM t')
    assert total == 10
# When majority of acceptors is offline, commits are expected to be frozen
def test_unavailability(neon_env_builder: NeonEnvBuilder):
    """With two safekeepers, stopping either one is expected to freeze commits."""
    branch = 'test_safekeepers_unavailability'
    neon_env_builder.num_safekeepers = 2
    env = neon_env_builder.init_start()

    env.neon_cli.create_branch(branch)
    compute = env.postgres.create_start(branch)

    asyncio.run(run_unavailability(env, compute))

View File

@@ -1,5 +1,5 @@
pytest_plugins = ("fixtures.neon_fixtures",
"fixtures.benchmark_fixture",
"fixtures.pg_stats",
"fixtures.compare_fixtures",
"fixtures.slow",
"fixtures.pg_stats")
"fixtures.slow")

View File

@@ -30,7 +30,7 @@ from dataclasses import dataclass
# Type-related stuff
from psycopg2.extensions import connection as PgConnection
from psycopg2.extensions import make_dsn, parse_dsn
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar, cast, Union, Tuple
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast, Union, Tuple
from typing_extensions import Literal
import requests
@@ -280,20 +280,18 @@ class PgProtocol:
return str(make_dsn(**self.conn_options(**kwargs)))
def conn_options(self, **kwargs):
conn_options = self.default_options.copy()
result = self.default_options.copy()
if 'dsn' in kwargs:
conn_options.update(parse_dsn(kwargs['dsn']))
conn_options.update(kwargs)
result.update(parse_dsn(kwargs['dsn']))
result.update(kwargs)
# Individual statement timeout in seconds. 2 minutes should be
# enough for our tests, but if you need a longer, you can
# change it by calling "SET statement_timeout" after
# connecting.
if 'options' in conn_options:
conn_options['options'] = f"-cstatement_timeout=120s " + conn_options['options']
else:
conn_options['options'] = "-cstatement_timeout=120s"
return conn_options
options = result.get('options', '')
result['options'] = f'-cstatement_timeout=120s {options}'
return result
# autocommit=True here by default because that's what we need most of the time
def connect(self, autocommit=True, **kwargs) -> PgConnection:
@@ -693,6 +691,10 @@ class NeonEnv:
""" Get list of safekeeper endpoints suitable for safekeepers GUC """
return ','.join([f'localhost:{wa.port.pg}' for wa in self.safekeepers])
def timeline_dir(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID) -> Path:
    """Return `<repo_dir>/tenants/<tenant hex>/timelines/<timeline hex>` for this environment."""
    tenant_dir = self.repo_dir / "tenants" / tenant_id.hex
    return tenant_dir / "timelines" / timeline_id.hex
@cached_property
def auth_keys(self) -> AuthKeys:
pub = (Path(self.repo_dir) / 'auth_public_key.pem').read_bytes()
@@ -865,8 +867,8 @@ class NeonPageserverHttpClient(requests.Session):
def timeline_detail(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID) -> Dict[Any, Any]:
res = self.get(
f"http://localhost:{self.port}/v1/tenant/{tenant_id.hex}/timeline/{timeline_id.hex}?include-non-incremental-logical-size=1"
)
f"http://localhost:{self.port}/v1/tenant/{tenant_id.hex}/timeline/{timeline_id.hex}" +
"?include-non-incremental-logical-size=1&include-non-incremental-physical-size=1")
self.verbose_error(res)
res_json = res.json()
assert isinstance(res_json, dict)
@@ -1514,29 +1516,25 @@ def remote_pg(test_output_dir: Path) -> Iterator[RemotePostgres]:
class NeonProxy(PgProtocol):
def __init__(self, port: int, pg_port: int):
super().__init__(host="127.0.0.1",
user="proxy_user",
password="pytest2",
port=port,
dbname='postgres')
self.http_port = 7001
self.host = "127.0.0.1"
self.port = port
self.pg_port = pg_port
def __init__(self, proxy_port: int, http_port: int, auth_endpoint: str):
super().__init__(dsn=auth_endpoint, port=proxy_port)
self.host = '127.0.0.1'
self.http_port = http_port
self.proxy_port = proxy_port
self.auth_endpoint = auth_endpoint
self._popen: Optional[subprocess.Popen[bytes]] = None
def start(self) -> None:
assert self._popen is None
# Start proxy
bin_proxy = os.path.join(str(neon_binpath), 'proxy')
args = [bin_proxy]
args.extend(["--http", f"{self.host}:{self.http_port}"])
args.extend(["--proxy", f"{self.host}:{self.port}"])
args.extend(["--auth-backend", "postgres"])
args.extend(
["--auth-endpoint", f"postgres://proxy_auth:pytest1@localhost:{self.pg_port}/postgres"])
args = [
os.path.join(str(neon_binpath), 'proxy'),
*["--http", f"{self.host}:{self.http_port}"],
*["--proxy", f"{self.host}:{self.proxy_port}"],
*["--auth-backend", "postgres"],
*["--auth-endpoint", self.auth_endpoint],
]
self._popen = subprocess.Popen(args)
self._wait_until_ready()
@@ -1557,13 +1555,21 @@ class NeonProxy(PgProtocol):
@pytest.fixture(scope='function')
def static_proxy(vanilla_pg, port_distributor) -> Iterator[NeonProxy]:
"""Neon proxy that routes directly to vanilla postgres."""
vanilla_pg.start()
vanilla_pg.safe_psql("create user proxy_auth with password 'pytest1' superuser")
vanilla_pg.safe_psql("create user proxy_user with password 'pytest2'")
port = port_distributor.get_port()
pg_port = vanilla_pg.default_options['port']
with NeonProxy(port, pg_port) as proxy:
# For simplicity, we use the same user for both `--auth-endpoint` and `safe_psql`
vanilla_pg.start()
vanilla_pg.safe_psql("create user proxy with login superuser password 'password'")
port = vanilla_pg.default_options['port']
host = vanilla_pg.default_options['host']
dbname = vanilla_pg.default_options['dbname']
auth_endpoint = f'postgres://proxy:password@{host}:{port}/{dbname}'
proxy_port = port_distributor.get_port()
http_port = port_distributor.get_port()
with NeonProxy(proxy_port=proxy_port, http_port=http_port,
auth_endpoint=auth_endpoint) as proxy:
proxy.start()
yield proxy
@@ -1923,7 +1929,7 @@ class SafekeeperHttpClient(requests.Session):
self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()
def timeline_status(self, tenant_id: str, timeline_id: str) -> SafekeeperTimelineStatus:
res = self.get(f"http://localhost:{self.port}/v1/timeline/{tenant_id}/{timeline_id}")
res = self.get(f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}")
res.raise_for_status()
resj = res.json()
return SafekeeperTimelineStatus(acceptor_epoch=resj['acceptor_state']['epoch'],

View File

@@ -1,9 +1,11 @@
import contextlib
import os
import pathlib
import shutil
import subprocess
from pathlib import Path
from typing import Any, List
from typing import Any, List, Tuple
from fixtures.log_helper import log
@@ -89,3 +91,36 @@ def get_dir_size(path: str) -> int:
pass # file could be concurrently removed
return totalbytes
def get_timeline_dir_size(path: pathlib.Path) -> int:
    """Return the combined size in bytes of the layer files directly inside `path`.

    An entry is counted when its name parses as an image layer name or, failing
    that, as a delta layer name. Entries matching neither are ignored.
    """
    total = 0
    for entry in path.iterdir():
        # Image layer? Any failure (name does not parse, stat() fails) falls
        # through to the delta-layer attempt below.
        try:
            parse_image_layer(entry.name)
            total += entry.stat().st_size
        except Exception:
            pass
        else:
            continue

        # Delta layer? A failure here means the entry contributes nothing.
        try:
            parse_delta_layer(entry.name)
            total += entry.stat().st_size
        except Exception:
            pass
    return total
def parse_image_layer(f_name: str) -> Tuple[int, int, int]:
    """Split an image layer file name into (key start, key end, snapshot LSN).

    The name is expected to look like `<key_start>-<key_end>__<lsn>`, with every
    field written in hexadecimal.
    """
    segments = f_name.split("__")
    key_bounds = segments[0].split("-")
    key_start = int(key_bounds[0], 16)
    key_end = int(key_bounds[1], 16)
    snapshot_lsn = int(segments[1], 16)
    return key_start, key_end, snapshot_lsn
def parse_delta_layer(f_name: str) -> Tuple[int, int, int, int]:
    """Split a delta layer file name into (key start, key end, LSN start, LSN end).

    The name is expected to look like `<key_start>-<key_end>__<lsn_start>-<lsn_end>`,
    with every field written in hexadecimal.
    """
    segments = f_name.split("__")
    key_bounds = segments[0].split("-")
    lsn_bounds = segments[1].split("-")
    key_start = int(key_bounds[0], 16)
    key_end = int(key_bounds[1], 16)
    lsn_start = int(lsn_bounds[0], 16)
    lsn_end = int(lsn_bounds[1], 16)
    return key_start, key_end, lsn_start, lsn_end

View File

@@ -0,0 +1,110 @@
import random
import time
import statistics
import threading
import timeit
import pytest
from typing import List
from fixtures.benchmark_fixture import MetricReport
from fixtures.compare_fixtures import NeonCompare
from fixtures.log_helper import log
def _record_branch_creation_durations(neon_compare: NeonCompare, durs: List[float]):
    """Report the max, mean and sample standard deviation of `durs` as benchmark metrics.

    Every metric is recorded with unit 's' and marked as lower-is-better.
    """
    summaries = (
        ("branch_creation_duration_max", max),
        ("branch_creation_duration_avg", statistics.mean),
        ("branch_creation_duration_stdev", statistics.stdev),
    )
    # Each value is computed right before it is recorded, so metrics are
    # emitted one by one in the order listed above.
    for metric_name, summarize in summaries:
        neon_compare.zenbenchmark.record(metric_name,
                                         summarize(durs),
                                         's',
                                         MetricReport.LOWER_IS_BETTER)
@pytest.mark.parametrize("n_branches", [20])
def test_branch_creation_heavy_write(neon_compare: NeonCompare, n_branches: int):
    """Measure the latency of branch creation during a heavy workload.

    To simulate a heavy workload, the tenant is created with GC and compaction
    settings tweaked so those tasks run more frequently, and `pgbench` is run in
    every new branch. Each branch is created from a randomly picked source branch.
    """
    env = neon_compare.env
    pg_bin = neon_compare.pg_bin

    # Use aggressive GC and checkpoint settings, so GC and compaction happen more often during the test
    tenant, _ = env.neon_cli.create_tenant(
        conf={
            'gc_period': '5 s',
            'gc_horizon': f'{4 * 1024 ** 2}',
            'checkpoint_distance': f'{2 * 1024 ** 2}',
            'compaction_target_size': f'{1024 ** 2}',
            'compaction_threshold': '2',
            # set PITR interval to be small, so we can do GC
            'pitr_interval': '5 s'
        })

    def run_pgbench(branch: str):
        log.info(f"Start a pgbench workload on branch {branch}")

        pg = env.postgres.create_start(branch, tenant_id=tenant)
        connstr = pg.connstr()

        pg_bin.run_capture(['pgbench', '-i', connstr])
        pg_bin.run_capture(['pgbench', '-c10', '-T10', connstr])

        pg.stop()

    workers: List[threading.Thread] = []

    def launch_workload(branch: str):
        # Register the thread before starting it so the final join loop sees it.
        worker = threading.Thread(target=run_pgbench, args=(branch, ), daemon=True)
        workers.append(worker)
        worker.start()

    env.neon_cli.create_branch('b0', tenant_id=tenant)
    launch_workload('b0')

    durations = []
    for i in range(n_branches):
        time.sleep(1.0)

        # pick a random existing branch (b0..b{i}) as the source
        p = random.randint(0, i)

        started_at = timeit.default_timer()
        env.neon_cli.create_branch('b{}'.format(i + 1), 'b{}'.format(p), tenant_id=tenant)
        dur = timeit.default_timer() - started_at

        log.info(f"Creating branch b{i+1} took {dur}s")
        durations.append(dur)

        launch_workload(f'b{i+1}')

    # Wait for every pgbench run to finish before reporting.
    for worker in workers:
        worker.join()

    _record_branch_creation_durations(neon_compare, durations)
@pytest.mark.parametrize("n_branches", [1024])
def test_branch_creation_many(neon_compare: NeonCompare, n_branches: int):
    """Measure the latency of branch creation when creating a lot of branches."""
    env = neon_compare.env

    # Populate the initial branch with pgbench data before branching from it.
    env.neon_cli.create_branch('b0')
    pg = env.postgres.create_start('b0')
    neon_compare.pg_bin.run_capture(['pgbench', '-i', '-s10', pg.connstr()])

    durations = []
    for i in range(n_branches):
        # pick a random existing branch (b0..b{i}) as the source
        p = random.randint(0, i)

        started_at = timeit.default_timer()
        env.neon_cli.create_branch('b{}'.format(i + 1), 'b{}'.format(p))
        elapsed = timeit.default_timer() - started_at
        durations.append(elapsed)

    _record_branch_creation_durations(neon_compare, durations)

View File

@@ -1,4 +1,6 @@
import os
import threading
import time
from typing import List
import pytest
@@ -87,3 +89,34 @@ def test_compare_pg_stats_wal_with_pgbench_default(neon_with_baseline: PgCompare
env.pg_bin.run_capture(
['pgbench', f'-T{duration}', f'--random-seed={seed}', env.pg.connstr()])
env.flush()
@pytest.mark.parametrize("n_tables", [1, 10])
@pytest.mark.parametrize("duration", get_durations_matrix(10))
def test_compare_pg_stats_wo_with_heavy_write(neon_with_baseline: PgCompare,
                                              n_tables: int,
                                              duration: int,
                                              pg_stats_wo: List[PgStatTable]):
    """Record the `pg_stats_wo` statistics while several threads insert concurrently.

    `n_tables` tables are created, then one thread per table repeatedly inserts
    batches of 1000 default-valued rows into its own table until `duration`
    seconds have elapsed. Statistics are recorded around the whole workload.
    """
    # Imported locally so this function does not depend on the module's import list.
    from contextlib import closing

    env = neon_with_baseline

    # `with conn.cursor()` only closes the cursor, so the connection is closed
    # explicitly to avoid leaking it.
    with closing(env.pg.connect()) as conn:
        with conn.cursor() as cur:
            for i in range(n_tables):
                cur.execute(
                    f"CREATE TABLE t{i}(key serial primary key, t text default 'foooooooooooooooooooooooooooooooooooooooooooooooooooo')"
                )

    def start_single_table_workload(table_id: int):
        start = time.time()
        # Each worker owns its connection and closes it when the workload ends.
        with closing(env.pg.connect()) as conn:
            with conn.cursor() as cur:
                while time.time() - start < duration:
                    cur.execute(f"INSERT INTO t{table_id} SELECT FROM generate_series(1,1000)")

    with env.record_pg_stats(pg_stats_wo):
        threads = [
            threading.Thread(target=start_single_table_workload, args=(i, ))
            for i in range(n_tables)
        ]

        for thread in threads:
            thread.start()
        # Keep recording until every worker has finished.
        for thread in threads:
            thread.join()

View File

@@ -0,0 +1,50 @@
import pytest
from contextlib import closing
from fixtures.compare_fixtures import PgCompare
from pytest_lazyfixture import lazy_fixture # type: ignore
@pytest.mark.parametrize(
    "env",
    [
        # The test is too slow to run in CI, but fast enough to run with remote tests
        pytest.param(lazy_fixture("neon_compare"), id="neon", marks=pytest.mark.slow),
        pytest.param(lazy_fixture("vanilla_compare"), id="vanilla", marks=pytest.mark.slow),
        pytest.param(lazy_fixture("remote_compare"), id="remote", marks=pytest.mark.remote_cluster),
    ])
def test_dup_key(env: PgCompare):
    """Update the same page many times, then measure read performance.

    The update phase is recorded as 'write' and the final full-table select as
    'read'.
    """
    with closing(env.pg.connect()) as conn, conn.cursor() as cur:
        cur.execute('drop table if exists t, f;')

        cur.execute("SET synchronous_commit=off")
        cur.execute("SET statement_timeout=0")

        # Write many updates to the same row
        with env.record_duration('write'):
            cur.execute("create table t (i integer, filler text);")
            cur.execute('insert into t values (0);')
            cur.execute("""
do $$
begin
for ivar in 1..5000000 loop
update t set i = ivar, filler = repeat('a', 50);
update t set i = ivar, filler = repeat('b', 50);
update t set i = ivar, filler = repeat('c', 50);
update t set i = ivar, filler = repeat('d', 50);
rollback;
end loop;
end;
$$;
""")

        # Write 3-4 MB to evict t from compute cache
        cur.execute('create table f (i integer);')
        cur.execute(f'insert into f values (generate_series(1,100000));')

        # Read
        with env.record_duration('read'):
            cur.execute('select * from t;')
            cur.fetchall()