Compare commits

...

27 Commits

Author SHA1 Message Date
Thang Pham
2a15415442 use parking_lot::RwLock in for page caches 2022-06-29 12:10:52 -04:00
Kirill Bulatov
4a05413a4c More code coverage fixes in GH Actions (#2002) 2022-06-27 22:40:20 +03:00
Kirill Bulatov
dd61f3558f Fix coverage upload credentials retrieval (#2001) 2022-06-27 20:41:09 +03:00
Kirill Bulatov
8a714f1ebf Add coverage to GH actions and rework part of them (#1987) 2022-06-27 19:15:56 +03:00
Arseny Sher
137291dc24 Push to etcd from safekeeper many timelines concurrently.
Mitigates latency fee, making push throughput 1-1.5 order of magnitude bigger.

Also make leases per timeline, not per whole safekeeper, avoiding storing
garbage in etcd for deleted timelines while safekeeper is alive.
2022-06-27 16:30:21 +03:00
Kirill Bulatov
eb8926083e Use the updated base build Docker image (#1972) 2022-06-27 13:12:58 +03:00
Johan Eliasson
26bca6ddba Add openssl to OSX dependencies (#1994) 2022-06-26 21:54:07 +03:00
Arthur Petukhovsky
55192384c3 Fix zero timeline_start_lsn (#1981)
* Fix zero timeline_start_lsn

* Log more info on control file upgrade

* Fix formatting

Co-authored-by: Anastasia Lubennikova <anastasia@neon.tech>
2022-06-24 13:59:37 +03:00
KlimentSerafimov
392cd8b1fc Refactored extracting project_name in console.rs. (#1982) 2022-06-24 05:57:33 -04:00
Alexey Kondratov
3cc531d093 Fix CREATE EXTENSION for non-db-owner users (#1408)
Previously, we were granting create only to db owner, but now we have a
dedicated 'web_access' role to connect via web UI and proxy link auth.

We anyway grant read / write all data to all roles, so let's grant
create to everyone too. This creates some provelege objects in each db,
which we need to drop before deleting the role. So now we reassign all
owned objects to each db owner before deletion. This also fixes deletion
of roles that created some data in any db previously. Will be tested by
https://github.com/neondatabase/cloud/pull/1673

Later we should stop messing with Postgres ACL that much.
2022-06-23 21:36:53 +02:00
bojanserafimov
84b9fcbbd5 Increase a few test timeouts (#1977) 2022-06-23 11:51:56 -04:00
Bojan Serafimov
93e050afe3 Don't require project name for link auth 2022-06-23 15:38:05 +03:00
Anastasia Lubennikova
6d7dc384a5 Add zenith-us-stage-ps-3 to deploy 2022-06-23 14:52:32 +03:00
Anastasia Lubennikova
3c2b03cd87 Update timeline size on dropdb. Add the test (#1973)
In addition, fix database size calculation:
count not only main fork of the relation, but also vm and fsm.
2022-06-23 12:28:12 +03:00
Kirill Bulatov
7c49abe7d1 Rework etcd timeline updates and their handling 2022-06-23 09:11:27 +03:00
KlimentSerafimov
d059e588a6 Added invariant check for project name. (#1921)
Summary: Added invariant checking for project name. Refactored ClientCredentials and TlsConfig.

* Added formatting invariant check for project name:
**\forall c \in project_name . c \in [alnum] U {'-'}. 
** sni_data == <project_name>.<common_name>
* Added exhaustive tests for get_project_name.
* Refactored TlsConfig to contain common_name : Option<String>.
* Refactored ClientCredentials construction to construct project_name directly.
* Merged ProjectNameError into ClientCredsParseError.
* Tweaked proxy tests to accommodate refactored ClientCredentials construction semantics. 
* [Pytests] Added project option argument to test_proxy_select_1.
* Removed project param from Api since now it's contained in creds.
* Refactored &Option<String> -> Option<&str>.

Co-authored-by: Dmitrii Ivanov <dima@neon.tech>.
2022-06-22 09:34:24 -04:00
Sergey Melnikov
6222a0012b Migrate from CircleCI to Github Actions: python codestyle, build and tests (#1647)
Duplicate postgres and neon build and test jobs from CircleCI to Github actions.
2022-06-22 11:40:59 +03:00
bojanserafimov
1ca28e6f3c Import basebackup into pageserver (#1925)
Allow importing basebackup taken from vanilla postgres or another pageserver via psql copy in protocol.
2022-06-21 11:04:10 -04:00
Arthur Petukhovsky
6c4d6a2183 Remove timeline_start_lsn check temporary. (#1964) 2022-06-21 02:02:24 +03:00
Thang Pham
37465dafe3 Add wal backpressure tests (#1919)
Resolves #1889.

This PR adds new tests to measure the WAL backpressure's performance under different workloads.

## Changes
- add new performance tests in `test_wal_backpressure.py`
- allow safekeeper's fsync to be configurable when running tests
2022-06-20 11:40:55 -04:00
Joshua D. Drake
ec0064c442 Small README.md changes (#1957)
* Update make instructions for release and debug build. Update PostgreSQL glossary to proper version (14)

* Continued cleanup of build instructions including removal of redundancies
2022-06-20 10:05:10 -04:00
Heikki Linnakangas
83c7e6ce52 Bump vendor/postgres.
This brings in the change to not use a shared memory in the WAL redo
process, to avoid running out of sysv shmem segments in the page server.

Also, removal of callmemaybe bits.
2022-06-20 15:28:43 +03:00
Arthur Petukhovsky
f862373ac0 Fix WAL timeout in test_s3_wal_replay (#1953) 2022-06-17 20:43:54 +03:00
Arthur Petukhovsky
699f46cd84 Download WAL from S3 if it's not available in safekeeper dir (#1932)
`send_wal.rs` and `WalReader` are now async. `test_s3_wal_replay` checks that WAL can be replayed after offloaded.
2022-06-17 15:33:39 +03:00
Anastasia Lubennikova
36ee182d26 Implement page servise 'fullbackup' endpoint (#1923)
* Implement page servise 'fullbackup' endpoint that works like basebackup, but also sends relational files

* Add test_runner/batch_others/test_fullbackup.py

Co-authored-by: bojanserafimov <bojan.serafimov7@gmail.com>
2022-06-16 14:07:11 +03:00
Anastasia Lubennikova
d11c9f9fcb Use random ports for the proxy and local pg in tests
Fixes #1931
Author: Dmitry Ivanov
2022-06-15 20:21:58 +03:00
Kirill Bulatov
d8a37452c8 Rename ZenithFeedback (#1912) 2022-06-11 00:44:05 +03:00
55 changed files with 4510 additions and 1841 deletions

View File

@@ -1,6 +1,7 @@
[pageservers]
#zenith-us-stage-ps-1 console_region_id=27
zenith-us-stage-ps-2 console_region_id=27
zenith-us-stage-ps-3 console_region_id=27
[safekeepers]
zenith-us-stage-sk-4 console_region_id=27

View File

@@ -286,7 +286,7 @@ jobs:
# no_output_timeout, specified here.
no_output_timeout: 10m
environment:
- ZENITH_BIN: /tmp/zenith/bin
- NEON_BIN: /tmp/zenith/bin
- POSTGRES_DISTRIB_DIR: /tmp/zenith/pg_install
- TEST_OUTPUT: /tmp/test_output
# this variable will be embedded in perf test report
@@ -688,50 +688,6 @@ jobs:
helm upgrade neon-proxy neondatabase/neon-proxy --install -f .circleci/helm-values/production.proxy.yaml --set image.tag=${DOCKER_TAG} --wait
helm upgrade neon-proxy-scram neondatabase/neon-proxy --install -f .circleci/helm-values/production.proxy-scram.yaml --set image.tag=${DOCKER_TAG} --wait
# Trigger a new remote CI job
remote-ci-trigger:
docker:
- image: cimg/base:2021.04
parameters:
remote_repo:
type: string
environment:
REMOTE_REPO: << parameters.remote_repo >>
steps:
- run:
name: Set PR's status to pending
command: |
LOCAL_REPO=$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME
curl -f -X POST \
https://api.github.com/repos/$LOCAL_REPO/statuses/$CIRCLE_SHA1 \
-H "Accept: application/vnd.github.v3+json" \
--user "$CI_ACCESS_TOKEN" \
--data \
"{
\"state\": \"pending\",
\"context\": \"neon-cloud-e2e\",
\"description\": \"[$REMOTE_REPO] Remote CI job is about to start\"
}"
- run:
name: Request a remote CI test
command: |
LOCAL_REPO=$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME
curl -f -X POST \
https://api.github.com/repos/$REMOTE_REPO/actions/workflows/testing.yml/dispatches \
-H "Accept: application/vnd.github.v3+json" \
--user "$CI_ACCESS_TOKEN" \
--data \
"{
\"ref\": \"main\",
\"inputs\": {
\"ci_job_name\": \"neon-cloud-e2e\",
\"commit_hash\": \"$CIRCLE_SHA1\",
\"remote_repo\": \"$LOCAL_REPO\"
}
}"
workflows:
build_and_test:
jobs:
@@ -880,14 +836,3 @@ workflows:
- release
requires:
- docker-image-release
- remote-ci-trigger:
# Context passes credentials for gh api
context: CI_ACCESS_TOKEN
remote_repo: "neondatabase/cloud"
requires:
# XXX: Successful build doesn't mean everything is OK, but
# the job to be triggered takes so much time to complete (~22 min)
# that it's better not to wait for the commented-out steps
- build-neon-release
# - pg_regress-tests-release
# - other-tests-release

View File

@@ -0,0 +1,140 @@
name: 'Run python test'
description: 'Runs a Neon python test set, performing all the required preparations before'
inputs:
build_type:
description: 'Type of Rust (neon) and C (postgres) builds. Must be "release" or "debug".'
required: true
rust_toolchain:
description: 'Rust toolchain version to fetch the caches'
required: true
test_selection:
description: 'A python test suite to run'
required: true
extra_params:
description: 'Arbitrary parameters to pytest. For example "-s" to prevent capturing stdout/stderr'
required: false
default: ''
needs_postgres_source:
description: 'Set to true if the test suite requires postgres source checked out'
required: false
default: 'false'
run_in_parallel:
description: 'Whether to run tests in parallel'
required: false
default: 'true'
save_perf_report:
description: 'Whether to upload the performance report'
required: false
default: 'false'
runs:
using: "composite"
steps:
- name: Get Neon artifact for restoration
uses: actions/download-artifact@v3
with:
name: neon-${{ runner.os }}-${{ inputs.build_type }}-${{ inputs.rust_toolchain }}-artifact
path: ./neon-artifact/
- name: Extract Neon artifact
shell: bash -ex {0}
run: |
mkdir -p /tmp/neon/
tar -xf ./neon-artifact/neon.tgz -C /tmp/neon/
rm -rf ./neon-artifact/
- name: Checkout
if: inputs.needs_postgres_source == 'true'
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 1
- name: Cache poetry deps
id: cache_poetry
uses: actions/cache@v3
with:
path: ~/.cache/pypoetry/virtualenvs
key: v1-${{ runner.os }}-python-deps-${{ hashFiles('poetry.lock') }}
- name: Install Python deps
shell: bash -ex {0}
run: ./scripts/pysync
- name: Run pytest
env:
NEON_BIN: /tmp/neon/bin
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
TEST_OUTPUT: /tmp/test_output
# this variable will be embedded in perf test report
# and is needed to distinguish different environments
PLATFORM: github-actions-selfhosted
shell: bash -ex {0}
run: |
PERF_REPORT_DIR="$(realpath test_runner/perf-report-local)"
rm -rf $PERF_REPORT_DIR
TEST_SELECTION="test_runner/${{ inputs.test_selection }}"
EXTRA_PARAMS="${{ inputs.extra_params }}"
if [ -z "$TEST_SELECTION" ]; then
echo "test_selection must be set"
exit 1
fi
if [[ "${{ inputs.run_in_parallel }}" == "true" ]]; then
EXTRA_PARAMS="-n4 $EXTRA_PARAMS"
fi
if [[ "${{ inputs.save_perf_report }}" == "true" ]]; then
if [[ "$GITHUB_REF" == "main" ]]; then
mkdir -p "$PERF_REPORT_DIR"
EXTRA_PARAMS="--out-dir $PERF_REPORT_DIR $EXTRA_PARAMS"
fi
fi
if [[ "${{ inputs.build_type }}" == "debug" ]]; then
cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/neon/coverage run)
elif [[ "${{ inputs.build_type }}" == "release" ]]; then
cov_prefix=()
fi
# Run the tests.
#
# The junit.xml file allows CircleCI to display more fine-grained test information
# in its "Tests" tab in the results page.
# --verbose prints name of each test (helpful when there are
# multiple tests in one file)
# -rA prints summary in the end
# -n4 uses four processes to run tests via pytest-xdist
# -s is not used to prevent pytest from capturing output, because tests are running
# in parallel and logs are mixed between different tests
"${cov_prefix[@]}" ./scripts/pytest \
--junitxml=$TEST_OUTPUT/junit.xml \
--tb=short \
--verbose \
-m "not remote_cluster" \
-rA $TEST_SELECTION $EXTRA_PARAMS
if [[ "${{ inputs.save_perf_report }}" == "true" ]]; then
if [[ "$GITHUB_REF" == "main" ]]; then
export REPORT_FROM="$PERF_REPORT_DIR"
export REPORT_TO=local
scripts/generate_and_push_perf_report.sh
fi
fi
- name: Delete all data but logs
shell: bash -ex {0}
if: always()
run: |
du -sh /tmp/test_output/*
find /tmp/test_output -type f ! -name "*.log" ! -name "regression.diffs" ! -name "junit.xml" ! -name "*.filediff" ! -name "*.stdout" ! -name "*.stderr" ! -name "flamegraph.svg" ! -name "*.metrics" -delete
du -sh /tmp/test_output/*
- name: Upload python test logs
if: always()
uses: actions/upload-artifact@v3
with:
retention-days: 7
if-no-files-found: error
name: python-test-${{ inputs.test_selection }}-${{ runner.os }}-${{ inputs.build_type }}-${{ inputs.rust_toolchain }}-logs
path: /tmp/test_output/

View File

@@ -0,0 +1,17 @@
name: 'Merge and upload coverage data'
description: 'Compresses and uploads the coverage data as an artifact'
runs:
using: "composite"
steps:
- name: Merge coverage data
shell: bash -ex {0}
run: scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/neon/coverage/ merge
- name: Upload coverage data
uses: actions/upload-artifact@v3
with:
retention-days: 7
if-no-files-found: error
name: coverage-data-artifact
path: /tmp/neon/coverage/

387
.github/workflows/build_and_test.yml vendored Normal file
View File

@@ -0,0 +1,387 @@
name: Test
on:
push:
branches:
- main
pull_request:
defaults:
run:
shell: bash -ex {0}
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
env:
RUST_BACKTRACE: 1
COPT: '-Werror'
AWS_ACCESS_KEY_ID: ${{ secrets.CACHEPOT_AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.CACHEPOT_AWS_SECRET_ACCESS_KEY }}
CACHEPOT_BUCKET: zenith-rust-cachepot
RUSTC_WRAPPER: cachepot
jobs:
build-postgres:
runs-on: [ self-hosted, Linux, k8s-runner ]
strategy:
fail-fast: false
matrix:
build_type: [ debug, release ]
rust_toolchain: [ 1.58 ]
env:
BUILD_TYPE: ${{ matrix.build_type }}
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 1
- name: Set pg revision for caching
id: pg_ver
run: echo ::set-output name=pg_rev::$(git rev-parse HEAD:vendor/postgres)
- name: Cache postgres build
id: cache_pg
uses: actions/cache@v3
with:
path: tmp_install/
key: v1-${{ runner.os }}-${{ matrix.build_type }}-pg-${{ steps.pg_ver.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Build postgres
if: steps.cache_pg.outputs.cache-hit != 'true'
run: COPT='-Werror' mold -run make postgres -j$(nproc)
# actions/cache@v3 does not allow concurrently using the same cache across job steps, so use a separate cache
- name: Prepare postgres artifact
run: tar -C tmp_install/ -czf ./pg.tgz .
- name: Upload postgres artifact
uses: actions/upload-artifact@v3
with:
retention-days: 7
if-no-files-found: error
name: postgres-${{ runner.os }}-${{ matrix.build_type }}-artifact
path: ./pg.tgz
build-neon:
runs-on: [ self-hosted, Linux, k8s-runner ]
needs: [ build-postgres ]
strategy:
fail-fast: false
matrix:
build_type: [ debug, release ]
rust_toolchain: [ 1.58 ]
env:
BUILD_TYPE: ${{ matrix.build_type }}
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 1
- name: Get postgres artifact for restoration
uses: actions/download-artifact@v3
with:
name: postgres-${{ runner.os }}-${{ matrix.build_type }}-artifact
path: ./postgres-artifact/
- name: Extract postgres artifact
run: |
mkdir ./tmp_install/
tar -xf ./postgres-artifact/pg.tgz -C ./tmp_install/
rm -rf ./postgres-artifact/
- name: Cache cargo deps
id: cache_cargo
uses: actions/cache@v3
with:
path: |
~/.cargo/registry/
~/.cargo/git/
target/
key: v2-${{ runner.os }}-${{ matrix.build_type }}-cargo-${{ matrix.rust_toolchain }}-${{ hashFiles('Cargo.lock') }}
- name: Run cargo build
run: |
if [[ $BUILD_TYPE == "debug" ]]; then
cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/neon/coverage run)
CARGO_FLAGS=
elif [[ $BUILD_TYPE == "release" ]]; then
cov_prefix=()
CARGO_FLAGS="--release --features profiling"
fi
"${cov_prefix[@]}" mold -run cargo build $CARGO_FLAGS --features failpoints --bins --tests
cachepot -s
- name: Run cargo test
run: |
if [[ $BUILD_TYPE == "debug" ]]; then
cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/neon/coverage run)
CARGO_FLAGS=
elif [[ $BUILD_TYPE == "release" ]]; then
cov_prefix=()
CARGO_FLAGS=--release
fi
"${cov_prefix[@]}" cargo test $CARGO_FLAGS
- name: Install rust binaries
run: |
if [[ $BUILD_TYPE == "debug" ]]; then
cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/neon/coverage run)
elif [[ $BUILD_TYPE == "release" ]]; then
cov_prefix=()
fi
binaries=$(
"${cov_prefix[@]}" cargo metadata --format-version=1 --no-deps |
jq -r '.packages[].targets[] | select(.kind | index("bin")) | .name'
)
test_exe_paths=$(
"${cov_prefix[@]}" cargo test --message-format=json --no-run |
jq -r '.executable | select(. != null)'
)
mkdir -p /tmp/neon/bin/
mkdir -p /tmp/neon/test_bin/
mkdir -p /tmp/neon/etc/
mkdir -p /tmp/neon/coverage/
# Install target binaries
for bin in $binaries; do
SRC=target/$BUILD_TYPE/$bin
DST=/tmp/neon/bin/$bin
cp "$SRC" "$DST"
done
# Install test executables and write list of all binaries (for code coverage)
if [[ $BUILD_TYPE == "debug" ]]; then
for bin in $binaries; do
echo "/tmp/neon/bin/$bin" >> /tmp/neon/coverage/binaries.list
done
for bin in $test_exe_paths; do
SRC=$bin
DST=/tmp/neon/test_bin/$(basename $bin)
cp "$SRC" "$DST"
echo "$DST" >> /tmp/neon/coverage/binaries.list
done
fi
- name: Install postgres binaries
run: cp -a tmp_install /tmp/neon/pg_install
- name: Prepare neon artifact
run: tar -C /tmp/neon/ -czf ./neon.tgz .
- name: Upload neon binaries
uses: actions/upload-artifact@v3
with:
retention-days: 7
if-no-files-found: error
name: neon-${{ runner.os }}-${{ matrix.build_type }}-${{ matrix.rust_toolchain }}-artifact
path: ./neon.tgz
# XXX: keep this after the binaries.list is formed, so the coverage can properly work later
- name: Merge and upload coverage data
if: matrix.build_type == 'debug'
uses: ./.github/actions/save-coverage-data
pg_regress-tests:
runs-on: [ self-hosted, Linux, k8s-runner ]
needs: [ build-neon ]
strategy:
fail-fast: false
matrix:
build_type: [ debug, release ]
rust_toolchain: [ 1.58 ]
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 2
- name: Pytest regress tests
uses: ./.github/actions/run-python-test-set
with:
build_type: ${{ matrix.build_type }}
rust_toolchain: ${{ matrix.rust_toolchain }}
test_selection: batch_pg_regress
needs_postgres_source: true
- name: Merge and upload coverage data
if: matrix.build_type == 'debug'
uses: ./.github/actions/save-coverage-data
other-tests:
runs-on: [ self-hosted, Linux, k8s-runner ]
needs: [ build-neon ]
strategy:
fail-fast: false
matrix:
build_type: [ debug, release ]
rust_toolchain: [ 1.58 ]
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 2
- name: Pytest other tests
uses: ./.github/actions/run-python-test-set
with:
build_type: ${{ matrix.build_type }}
rust_toolchain: ${{ matrix.rust_toolchain }}
test_selection: batch_others
- name: Merge and upload coverage data
if: matrix.build_type == 'debug'
uses: ./.github/actions/save-coverage-data
benchmarks:
runs-on: [ self-hosted, Linux, k8s-runner ]
needs: [ build-neon ]
strategy:
fail-fast: false
matrix:
build_type: [ release ]
rust_toolchain: [ 1.58 ]
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 2
- name: Pytest benchmarks
uses: ./.github/actions/run-python-test-set
with:
build_type: ${{ matrix.build_type }}
rust_toolchain: ${{ matrix.rust_toolchain }}
test_selection: performance
run_in_parallel: false
save_perf_report: true
# XXX: no coverage data handling here, since benchmarks are run on release builds,
# while coverage is currently collected for the debug ones
coverage-report:
runs-on: [ self-hosted, Linux, k8s-runner ]
needs: [ other-tests, pg_regress-tests ]
strategy:
fail-fast: false
matrix:
build_type: [ debug ]
rust_toolchain: [ 1.58 ]
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 1
- name: Restore cargo deps cache
id: cache_cargo
uses: actions/cache@v3
with:
path: |
~/.cargo/registry/
~/.cargo/git/
target/
key: v2-${{ runner.os }}-${{ matrix.build_type }}-cargo-${{ matrix.rust_toolchain }}-${{ hashFiles('Cargo.lock') }}
- name: Get Neon artifact for restoration
uses: actions/download-artifact@v3
with:
name: neon-${{ runner.os }}-${{ matrix.build_type }}-${{ matrix.rust_toolchain }}-artifact
path: ./neon-artifact/
- name: Extract Neon artifact
run: |
mkdir -p /tmp/neon/
tar -xf ./neon-artifact/neon.tgz -C /tmp/neon/
rm -rf ./neon-artifact/
- name: Restore coverage data
uses: actions/download-artifact@v3
with:
name: coverage-data-artifact
path: /tmp/neon/coverage/
- name: Build and upload coverage report
run: |
COMMIT_SHA=${{ github.event.pull_request.head.sha }}
COMMIT_SHA=${COMMIT_SHA:-${{ github.sha }}}
COMMIT_URL=https://github.com/${{ github.repository }}/commit/$COMMIT_SHA
scripts/coverage \
--dir=/tmp/neon/coverage report \
--input-objects=/tmp/neon/coverage/binaries.list \
--commit-url=$COMMIT_URL \
--format=github
REPORT_URL=https://${{ github.repository_owner }}.github.io/zenith-coverage-data/$COMMIT_SHA
scripts/git-upload \
--repo=https://${{ secrets.VIP_VAP_ACCESS_TOKEN }}@github.com/${{ github.repository_owner }}/zenith-coverage-data.git \
--message="Add code coverage for $COMMIT_URL" \
copy /tmp/neon/coverage/report $COMMIT_SHA # COPY FROM TO_RELATIVE
# Add link to the coverage report to the commit
curl -f -X POST \
https://api.github.com/repos/${{ github.repository }}/statuses/$COMMIT_SHA \
-H "Accept: application/vnd.github.v3+json" \
--user "${{ secrets.CI_ACCESS_TOKEN }}" \
--data \
"{
\"state\": \"success\",
\"context\": \"neon-coverage\",
\"description\": \"Coverage report is ready\",
\"target_url\": \"$REPORT_URL\"
}"
trigger-e2e-tests:
runs-on: [ self-hosted, Linux, k8s-runner ]
needs: [ build-neon ]
steps:
- name: Set PR's status to pending and request a remote CI test
run: |
COMMIT_SHA=${{ github.event.pull_request.head.sha }}
COMMIT_SHA=${COMMIT_SHA:-${{ github.sha }}}
REMOTE_REPO="${{ github.repository_owner }}/cloud"
curl -f -X POST \
https://api.github.com/repos/${{ github.repository }}/statuses/$COMMIT_SHA \
-H "Accept: application/vnd.github.v3+json" \
--user "${{ secrets.CI_ACCESS_TOKEN }}" \
--data \
"{
\"state\": \"pending\",
\"context\": \"neon-cloud-e2e\",
\"description\": \"[$REMOTE_REPO] Remote CI job is about to start\"
}"
curl -f -X POST \
https://api.github.com/repos/$REMOTE_REPO/actions/workflows/testing.yml/dispatches \
-H "Accept: application/vnd.github.v3+json" \
--user "${{ secrets.CI_ACCESS_TOKEN }}" \
--data \
"{
\"ref\": \"main\",
\"inputs\": {
\"ci_job_name\": \"neon-cloud-e2e\",
\"commit_hash\": \"$COMMIT_SHA\",
\"remote_repo\": \"${{ github.repository }}\"
}
}"

View File

@@ -1,4 +1,4 @@
name: Build and Test
name: Check code style and build
on:
push:
@@ -6,9 +6,21 @@ on:
- main
pull_request:
defaults:
run:
shell: bash -ex {0}
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
env:
RUST_BACKTRACE: 1
jobs:
regression-check:
check-codestyle-rust:
strategy:
fail-fast: false
matrix:
# If we want to duplicate this job for different
# Rust toolchains (e.g. nightly or 1.37.0), add them here.
@@ -92,5 +104,30 @@ jobs:
- name: Run cargo clippy
run: ./run_clippy.sh
- name: Run cargo test
run: cargo test --all --all-targets
- name: Ensure all project builds
run: cargo build --all --all-targets
check-codestyle-python:
runs-on: [ self-hosted, Linux, k8s-runner ]
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: false
fetch-depth: 1
- name: Cache poetry deps
id: cache_poetry
uses: actions/cache@v3
with:
path: ~/.cache/pypoetry/virtualenvs
key: v1-codestyle-python-deps-${{ hashFiles('poetry.lock') }}
- name: Install Python deps
run: ./scripts/pysync
- name: Run yapf to ensure code format
run: poetry run yapf --recursive --diff .
- name: Run mypy to check types
run: poetry run mypy .

138
Cargo.lock generated
View File

@@ -64,6 +64,45 @@ dependencies = [
"nodrop",
]
[[package]]
name = "asn1-rs"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30ff05a702273012438132f449575dbc804e27b2f3cbe3069aa237d26c98fa33"
dependencies = [
"asn1-rs-derive",
"asn1-rs-impl",
"displaydoc",
"nom",
"num-traits",
"rusticata-macros",
"thiserror",
"time 0.3.9",
]
[[package]]
name = "asn1-rs-derive"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db8b7511298d5b7784b40b092d9e9dcd3a627a5707e4b5e507931ab0d44eeebf"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]]
name = "asn1-rs-impl"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2777730b2039ac0f95f093556e61b6d26cebed5393ca6f152717777cec3a42ed"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "async-stream"
version = "0.3.3"
@@ -422,6 +461,7 @@ dependencies = [
"tar",
"tokio",
"tokio-postgres",
"urlencoding",
"workspace_hack",
]
@@ -610,7 +650,7 @@ dependencies = [
"crossterm_winapi",
"libc",
"mio",
"parking_lot 0.12.0",
"parking_lot 0.12.1",
"signal-hook",
"signal-hook-mio",
"winapi",
@@ -712,6 +752,12 @@ dependencies = [
"syn",
]
[[package]]
name = "data-encoding"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57"
[[package]]
name = "debugid"
version = "0.7.3"
@@ -721,6 +767,20 @@ dependencies = [
"uuid",
]
[[package]]
name = "der-parser"
version = "7.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe398ac75057914d7d07307bf67dc7f3f574a26783b4fc7805a20ffa9f506e82"
dependencies = [
"asn1-rs",
"displaydoc",
"nom",
"num-bigint",
"num-traits",
"rusticata-macros",
]
[[package]]
name = "digest"
version = "0.9.0"
@@ -762,6 +822,17 @@ dependencies = [
"winapi",
]
[[package]]
name = "displaydoc"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3bf95dc3f046b9da4f2d51833c0d3547d8564ef6910f5c1ed130306a75b92886"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "either"
version = "1.6.1"
@@ -1731,6 +1802,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "oid-registry"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38e20717fa0541f39bd146692035c37bedfa532b3e5071b35761082407546b2a"
dependencies = [
"asn1-rs",
]
[[package]]
name = "once_cell"
version = "1.10.0"
@@ -1819,6 +1899,7 @@ dependencies = [
"metrics",
"nix",
"once_cell",
"parking_lot 0.12.1",
"postgres",
"postgres-protocol",
"postgres-types",
@@ -1842,6 +1923,7 @@ dependencies = [
"tracing",
"url",
"utils",
"walkdir",
"workspace_hack",
]
@@ -1858,9 +1940,9 @@ dependencies = [
[[package]]
name = "parking_lot"
version = "0.12.0"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58"
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
dependencies = [
"lock_api",
"parking_lot_core 0.9.2",
@@ -2227,7 +2309,7 @@ dependencies = [
"lazy_static",
"md5",
"metrics",
"parking_lot 0.12.0",
"parking_lot 0.12.1",
"pin-project-lite",
"rand",
"rcgen",
@@ -2249,6 +2331,7 @@ dependencies = [
"url",
"utils",
"workspace_hack",
"x509-parser",
]
[[package]]
@@ -2620,6 +2703,15 @@ dependencies = [
"semver",
]
[[package]]
name = "rusticata-macros"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632"
dependencies = [
"nom",
]
[[package]]
name = "rustls"
version = "0.20.4"
@@ -3059,6 +3151,18 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8"
[[package]]
name = "synstructure"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f"
dependencies = [
"proc-macro2",
"quote",
"syn",
"unicode-xid",
]
[[package]]
name = "tar"
version = "0.4.38"
@@ -3253,7 +3357,7 @@ dependencies = [
"fallible-iterator",
"futures",
"log",
"parking_lot 0.12.0",
"parking_lot 0.12.1",
"percent-encoding",
"phf",
"pin-project-lite",
@@ -3582,6 +3686,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "urlencoding"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68b90931029ab9b034b300b797048cf23723400aa757e8a2bfb9d748102f9821"
[[package]]
name = "utils"
version = "0.1.0"
@@ -3921,6 +4031,24 @@ dependencies = [
"tracing-core",
]
[[package]]
name = "x509-parser"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fb9bace5b5589ffead1afb76e43e34cff39cd0f3ce7e170ae0c29e53b88eb1c"
dependencies = [
"asn1-rs",
"base64",
"data-encoding",
"der-parser",
"lazy_static",
"nom",
"oid-registry",
"rusticata-macros",
"thiserror",
"time 0.3.9",
]
[[package]]
name = "xattr"
version = "0.2.2"

View File

@@ -1,5 +1,5 @@
# Build Postgres
FROM zimg/rust:1.58 AS pg-build
FROM neondatabase/rust:1.58 AS pg-build
WORKDIR /pg
USER root
@@ -14,7 +14,7 @@ RUN set -e \
&& tar -C tmp_install -czf /postgres_install.tar.gz .
# Build zenith binaries
FROM zimg/rust:1.58 AS build
FROM neondatabase/rust:1.58 AS build
ARG GIT_VERSION=local
ARG CACHEPOT_BUCKET=zenith-rust-cachepot
@@ -46,9 +46,9 @@ RUN set -e \
&& useradd -d /data zenith \
&& chown -R zenith:zenith /data
COPY --from=build --chown=zenith:zenith /home/circleci/project/target/release/pageserver /usr/local/bin
COPY --from=build --chown=zenith:zenith /home/circleci/project/target/release/safekeeper /usr/local/bin
COPY --from=build --chown=zenith:zenith /home/circleci/project/target/release/proxy /usr/local/bin
COPY --from=build --chown=zenith:zenith /home/runner/target/release/pageserver /usr/local/bin
COPY --from=build --chown=zenith:zenith /home/runner/target/release/safekeeper /usr/local/bin
COPY --from=build --chown=zenith:zenith /home/runner/target/release/proxy /usr/local/bin
COPY --from=pg-build /pg/tmp_install/ /usr/local/
COPY --from=pg-build /postgres_install.tar.gz /data/

View File

@@ -1,6 +1,6 @@
# First transient image to build compute_tools binaries
# NB: keep in sync with rust image version in .circle/config.yml
FROM zimg/rust:1.58 AS rust-build
FROM neondatabase/rust:1.58 AS rust-build
ARG CACHEPOT_BUCKET=zenith-rust-cachepot
ARG AWS_ACCESS_KEY_ID
@@ -15,4 +15,4 @@ RUN set -e \
# Final image that only has one binary
FROM debian:buster-slim
COPY --from=rust-build /home/circleci/project/target/release/compute_ctl /usr/local/bin/compute_ctl
COPY --from=rust-build /home/runner/target/release/compute_ctl /usr/local/bin/compute_ctl

View File

@@ -29,7 +29,7 @@ Pageserver consists of:
## Running local installation
#### building on Linux
#### Installing dependencies on Linux
1. Install build dependencies and other useful packages
* On Ubuntu or Debian this set of packages should be sufficient to build the code:
@@ -49,18 +49,11 @@ dnf install flex bison readline-devel zlib-devel openssl-devel \
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```
3. Build neon and patched postgres
```sh
git clone --recursive https://github.com/neondatabase/neon.git
cd neon
make -j`nproc`
```
#### building on OSX (12.3.1)
#### Installing dependencies on OSX (12.3.1)
1. Install XCode and dependencies
```
xcode-select --install
brew install protobuf etcd
brew install protobuf etcd openssl
```
2. [Install Rust](https://www.rust-lang.org/tools/install)
@@ -76,10 +69,19 @@ brew install libpq
brew link --force libpq
```
4. Build neon and patched postgres
```sh
#### Building on Linux and OSX
1. Build neon and patched postgres
```
# Note: The path to the neon sources can not contain a space.
git clone --recursive https://github.com/neondatabase/neon.git
cd neon
# The preferred and default is to make a debug build. This will create a
# demonstrably slower build than a release build. If you want to use a release
# build, utilize "`BUILD_TYPE=release make -j`nproc``"
make -j`nproc`
```
@@ -209,7 +211,7 @@ Same applies to certain spelling: i.e. we use MB to denote 1024 * 1024 bytes, wh
To get more familiar with this aspect, refer to:
- [Neon glossary](/docs/glossary.md)
- [PostgreSQL glossary](https://www.postgresql.org/docs/13/glossary.html)
- [PostgreSQL glossary](https://www.postgresql.org/docs/14/glossary.html)
- Other PostgreSQL documentation and sources (Neon fork sources can be found [here](https://github.com/neondatabase/postgres))
## Join the development

View File

@@ -18,4 +18,5 @@ serde_json = "1"
tar = "0.4"
tokio = { version = "1.17", features = ["macros", "rt", "rt-multi-thread"] }
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="d052ee8b86fff9897c77b0fe89ea9daba0e1fa38" }
urlencoding = "2.1.0"
workspace_hack = { version = "0.1", path = "../workspace_hack" }

View File

@@ -289,6 +289,7 @@ impl ComputeNode {
handle_roles(&self.spec, &mut client)?;
handle_databases(&self.spec, &mut client)?;
handle_role_deletions(self, &mut client)?;
handle_grants(&self.spec, &mut client)?;
create_writablity_check_data(&mut client)?;

View File

@@ -2,9 +2,11 @@ use std::path::Path;
use anyhow::Result;
use log::{info, log_enabled, warn, Level};
use postgres::Client;
use postgres::{Client, NoTls};
use serde::Deserialize;
use urlencoding::encode;
use crate::compute::ComputeNode;
use crate::config;
use crate::params::PG_HBA_ALL_MD5;
use crate::pg_helpers::*;
@@ -97,18 +99,13 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
// Process delta operations first
if let Some(ops) = &spec.delta_operations {
info!("processing delta operations on roles");
info!("processing role renames");
for op in ops {
match op.action.as_ref() {
// We do not check either role exists or not,
// Postgres will take care of it for us
"delete_role" => {
let query: String = format!("DROP ROLE IF EXISTS {}", &op.name.quote());
warn!("deleting role '{}'", &op.name);
xact.execute(query.as_str(), &[])?;
// no-op now, roles will be deleted at the end of configuration
}
// Renaming role drops its password, since tole name is
// Renaming role drops its password, since role name is
// used as a salt there. It is important that this role
// is recorded with a new `name` in the `roles` list.
// Follow up roles update will set the new password.
@@ -182,7 +179,7 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
xact.execute(query.as_str(), &[])?;
let grant_query = format!(
"grant pg_read_all_data, pg_write_all_data to {}",
"GRANT pg_read_all_data, pg_write_all_data TO {}",
name.quote()
);
xact.execute(grant_query.as_str(), &[])?;
@@ -197,6 +194,68 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
Ok(())
}
/// Reassign all dependent objects and delete requested roles.
pub fn handle_role_deletions(node: &ComputeNode, client: &mut Client) -> Result<()> {
let spec = &node.spec;
// First, reassign all dependent objects to db owners.
if let Some(ops) = &spec.delta_operations {
info!("reassigning dependent objects of to-be-deleted roles");
for op in ops {
if op.action == "delete_role" {
reassign_owned_objects(node, &op.name)?;
}
}
}
// Second, proceed with role deletions.
let mut xact = client.transaction()?;
if let Some(ops) = &spec.delta_operations {
info!("processing role deletions");
for op in ops {
// We do not check either role exists or not,
// Postgres will take care of it for us
if op.action == "delete_role" {
let query: String = format!("DROP ROLE IF EXISTS {}", &op.name.quote());
warn!("deleting role '{}'", &op.name);
xact.execute(query.as_str(), &[])?;
}
}
}
Ok(())
}
// Reassign all owned objects in all databases to the owner of the database.
fn reassign_owned_objects(node: &ComputeNode, role_name: &PgIdent) -> Result<()> {
for db in &node.spec.cluster.databases {
if db.owner != *role_name {
let db_name_encoded = format!("/{}", encode(&db.name));
let db_connstr = node.connstr.replacen("/postgres", &db_name_encoded, 1);
let mut client = Client::connect(&db_connstr, NoTls)?;
// This will reassign all dependent objects to the db owner
let reassign_query = format!(
"REASSIGN OWNED BY {} TO {}",
role_name.quote(),
db.owner.quote()
);
info!(
"reassigning objects owned by '{}' in db '{}' to '{}'",
role_name, &db.name, &db.owner
);
client.simple_query(&reassign_query)?;
// This now will only drop privileges of the role
let drop_query = format!("DROP OWNED BY {}", role_name.quote());
client.simple_query(&drop_query)?;
}
}
Ok(())
}
/// It follows mostly the same logic as `handle_roles()` excepting that we
/// does not use an explicit transactions block, since major database operations
/// like `CREATE DATABASE` and `DROP DATABASE` do not support it. Statement-level
@@ -294,13 +353,26 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
pub fn handle_grants(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
info!("cluster spec grants:");
// We now have a separate `web_access` role to connect to the database
// via the web interface and proxy link auth. And also we grant a
// read / write all data privilege to every role. So also grant
// create to everyone.
// XXX: later we should stop messing with Postgres ACL in such horrible
// ways.
let roles = spec
.cluster
.roles
.iter()
.map(|r| r.name.quote())
.collect::<Vec<_>>();
for db in &spec.cluster.databases {
let dbname = &db.name;
let query: String = format!(
"GRANT CREATE ON DATABASE {} TO {}",
dbname.quote(),
db.owner.quote()
roles.join(", ")
);
info!("grant query {}", &query);

View File

@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::io::Write;
use std::fs::File;
use std::io::{BufReader, Write};
use std::net::TcpStream;
use std::num::NonZeroU64;
use std::path::PathBuf;
@@ -527,4 +528,54 @@ impl PageServerNode {
Ok(timeline_info_response)
}
/// Import a basebackup prepared using either:
/// a) `pg_basebackup -F tar`, or
/// b) The `fullbackup` pageserver endpoint
///
/// # Arguments
/// * `tenant_id` - tenant to import into. Created if not exists
/// * `timeline_id` - id to assign to imported timeline
/// * `base` - (start lsn of basebackup, path to `base.tar` file)
/// * `pg_wal` - if there's any wal to import: (end lsn, path to `pg_wal.tar`)
pub fn timeline_import(
&self,
tenant_id: ZTenantId,
timeline_id: ZTimelineId,
base: (Lsn, PathBuf),
pg_wal: Option<(Lsn, PathBuf)>,
) -> anyhow::Result<()> {
let mut client = self.pg_connection_config.connect(NoTls).unwrap();
// Init base reader
let (start_lsn, base_tarfile_path) = base;
let base_tarfile = File::open(base_tarfile_path)?;
let mut base_reader = BufReader::new(base_tarfile);
// Init wal reader if necessary
let (end_lsn, wal_reader) = if let Some((end_lsn, wal_tarfile_path)) = pg_wal {
let wal_tarfile = File::open(wal_tarfile_path)?;
let wal_reader = BufReader::new(wal_tarfile);
(end_lsn, Some(wal_reader))
} else {
(start_lsn, None)
};
// Import base
let import_cmd =
format!("import basebackup {tenant_id} {timeline_id} {start_lsn} {end_lsn}");
let mut writer = client.copy_in(&import_cmd)?;
io::copy(&mut base_reader, &mut writer)?;
writer.finish()?;
// Import wal if necessary
if let Some(mut wal_reader) = wal_reader {
let import_cmd = format!("import wal {tenant_id} {timeline_id} {start_lsn} {end_lsn}");
let mut writer = client.copy_in(&import_cmd)?;
io::copy(&mut wal_reader, &mut writer)?;
writer.finish()?;
}
Ok(())
}
}

View File

@@ -36,12 +36,12 @@ This is how the `LOGICAL_TIMELINE_SIZE` metric is implemented in the pageserver.
Alternatively, we could count only relation data. As in pg_database_size().
This approach is somewhat more user-friendly because it is the data that is really affected by the user.
On the other hand, it puts us in a weaker position than other services, i.e., RDS.
We will need to refactor the timeline_size counter or add another counter to implement it.
We will need to refactor the timeline_size counter or add another counter to implement it.
Timeline size is updated during wal digestion. It is not versioned and is valid at the last_received_lsn moment.
Then this size should be reported to compute node.
`current_timeline_size` value is included in the walreceiver's custom feedback message: `ZenithFeedback.`
`current_timeline_size` value is included in the walreceiver's custom feedback message: `ReplicationFeedback.`
(PR about protocol changes https://github.com/zenithdb/zenith/pull/1037).
@@ -64,11 +64,11 @@ We should warn users if the limit is soon to be reached.
### **Reliability, failure modes and corner cases**
1. `current_timeline_size` is valid at the last received and digested by pageserver lsn.
If pageserver lags behind compute node, `current_timeline_size` will lag too. This lag can be tuned using backpressure, but it is not expected to be 0 all the time.
So transactions that happen in this lsn range may cause limit overflow. Especially operations that generate (i.e., CREATE DATABASE) or free (i.e., TRUNCATE) a lot of data pages while generating a small amount of WAL. Are there other operations like this?
Currently, CREATE DATABASE operations are restricted in the console. So this is not an issue.

View File

@@ -6,17 +6,13 @@ pub mod subscription_key;
/// All broker values, possible to use when dealing with etcd.
pub mod subscription_value;
use std::{
collections::{hash_map, HashMap},
str::FromStr,
};
use std::str::FromStr;
use serde::de::DeserializeOwned;
use subscription_key::SubscriptionKey;
use tokio::{sync::mpsc, task::JoinHandle};
use tracing::*;
use utils::zid::{NodeId, ZTenantTimelineId};
use crate::subscription_key::SubscriptionFullKey;
@@ -28,18 +24,17 @@ pub const DEFAULT_NEON_BROKER_ETCD_PREFIX: &str = "neon";
/// A way to control the data retrieval from a certain subscription.
pub struct BrokerSubscription<V> {
value_updates: mpsc::UnboundedReceiver<HashMap<ZTenantTimelineId, HashMap<NodeId, V>>>,
/// An unbounded channel to fetch the relevant etcd updates from.
pub value_updates: mpsc::UnboundedReceiver<BrokerUpdate<V>>,
key: SubscriptionKey,
watcher_handle: JoinHandle<Result<(), BrokerError>>,
/// A subscription task handle, to allow waiting on it for the task to complete.
/// Both the updates channel and the handle require `&mut`, so it's better to keep
/// both `pub` to allow using both in the same structures without borrow checker complaining.
pub watcher_handle: JoinHandle<Result<(), BrokerError>>,
watcher: Watcher,
}
impl<V> BrokerSubscription<V> {
/// Asynchronously polls for more data from the subscription, suspending the current future if there's no data sent yet.
pub async fn fetch_data(&mut self) -> Option<HashMap<ZTenantTimelineId, HashMap<NodeId, V>>> {
self.value_updates.recv().await
}
/// Cancels the subscription, stopping the data poller and waiting for it to shut down.
pub async fn cancel(mut self) -> Result<(), BrokerError> {
self.watcher.cancel().await.map_err(|e| {
@@ -48,15 +43,41 @@ impl<V> BrokerSubscription<V> {
format!("Failed to cancel broker subscription, kind: {:?}", self.key),
)
})?;
self.watcher_handle.await.map_err(|e| {
BrokerError::InternalError(format!(
"Failed to join the broker value updates task, kind: {:?}, error: {e}",
self.key
))
})?
match (&mut self.watcher_handle).await {
Ok(res) => res,
Err(e) => {
if e.is_cancelled() {
// don't error on the tasks that are cancelled already
Ok(())
} else {
Err(BrokerError::InternalError(format!(
"Panicked during broker subscription task, kind: {:?}, error: {e}",
self.key
)))
}
}
}
}
}
impl<V> Drop for BrokerSubscription<V> {
fn drop(&mut self) {
// we poll data from etcd into the channel in the same struct, so if the whole struct gets dropped,
// no more data is used by the receiver and it's safe to cancel and drop the whole etcd subscription task.
self.watcher_handle.abort();
}
}
/// An update from the etcd broker.
pub struct BrokerUpdate<V> {
/// Etcd generation version, the bigger the more actual the data is.
pub etcd_version: i64,
/// Etcd key for the corresponding value, parsed from the broker KV.
pub key: SubscriptionFullKey,
/// Current etcd value, parsed from the broker KV.
pub value: V,
}
#[derive(Debug, thiserror::Error)]
pub enum BrokerError {
#[error("Etcd client error: {0}. Context: {1}")]
@@ -124,41 +145,21 @@ where
break;
}
let mut value_updates: HashMap<ZTenantTimelineId, HashMap<NodeId, V>> = HashMap::new();
// Keep track that the timeline data updates from etcd arrive in the right order.
// https://etcd.io/docs/v3.5/learning/api_guarantees/#isolation-level-and-consistency-of-replicas
// > etcd does not ensure linearizability for watch operations. Users are expected to verify the revision of watch responses to ensure correct ordering.
let mut value_etcd_versions: HashMap<ZTenantTimelineId, i64> = HashMap::new();
let events = resp.events();
debug!("Processing {} events", events.len());
for event in events {
if EventType::Put == event.event_type() {
if let Some(new_etcd_kv) = event.kv() {
let new_kv_version = new_etcd_kv.version();
match parse_etcd_kv(new_etcd_kv, &value_parser, &key.cluster_prefix) {
Ok(Some((key, value))) => match value_updates
.entry(key.id)
.or_default()
.entry(key.node_id)
{
hash_map::Entry::Occupied(mut o) => {
let old_etcd_kv_version = value_etcd_versions.get(&key.id).copied().unwrap_or(i64::MIN);
if old_etcd_kv_version < new_kv_version {
o.insert(value);
value_etcd_versions.insert(key.id,new_kv_version);
} else {
debug!("Skipping etcd timeline update due to older version compared to one that's already stored");
}
}
hash_map::Entry::Vacant(v) => {
v.insert(value);
value_etcd_versions.insert(key.id,new_kv_version);
}
},
Ok(Some((key, value))) => if let Err(e) = value_updates_sender.send(BrokerUpdate {
etcd_version: new_etcd_kv.version(),
key,
value,
}) {
info!("Broker value updates for key {key:?} sender got dropped, exiting: {e}");
break;
},
Ok(None) => debug!("Ignoring key {key:?} : no value was returned by the parser"),
Err(BrokerError::KeyNotParsed(e)) => debug!("Unexpected key {key:?} for timeline update: {e}"),
Err(e) => error!("Failed to represent etcd KV {new_etcd_kv:?}: {e}"),
@@ -166,13 +167,6 @@ where
}
}
}
if !value_updates.is_empty() {
if let Err(e) = value_updates_sender.send(value_updates) {
info!("Broker value updates for key {key:?} sender got dropped, exiting: {e}");
break;
}
}
}
Ok(())

View File

@@ -926,10 +926,10 @@ impl<'a> BeMessage<'a> {
}
}
// Zenith extension of postgres replication protocol
// See ZENITH_STATUS_UPDATE_TAG_BYTE
// Neon extension of postgres replication protocol
// See NEON_STATUS_UPDATE_TAG_BYTE
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ZenithFeedback {
pub struct ReplicationFeedback {
// Last known size of the timeline. Used to enforce timeline size limit.
pub current_timeline_size: u64,
// Parts of StandbyStatusUpdate we resend to compute via safekeeper
@@ -939,13 +939,13 @@ pub struct ZenithFeedback {
pub ps_replytime: SystemTime,
}
// NOTE: Do not forget to increment this number when adding new fields to ZenithFeedback.
// NOTE: Do not forget to increment this number when adding new fields to ReplicationFeedback.
// Do not remove previously available fields because this might be backwards incompatible.
pub const ZENITH_FEEDBACK_FIELDS_NUMBER: u8 = 5;
pub const REPLICATION_FEEDBACK_FIELDS_NUMBER: u8 = 5;
impl ZenithFeedback {
pub fn empty() -> ZenithFeedback {
ZenithFeedback {
impl ReplicationFeedback {
pub fn empty() -> ReplicationFeedback {
ReplicationFeedback {
current_timeline_size: 0,
ps_writelsn: 0,
ps_applylsn: 0,
@@ -954,7 +954,7 @@ impl ZenithFeedback {
}
}
// Serialize ZenithFeedback using custom format
// Serialize ReplicationFeedback using custom format
// to support protocol extensibility.
//
// Following layout is used:
@@ -965,7 +965,7 @@ impl ZenithFeedback {
// uint32 - value length in bytes
// value itself
pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> {
buf.put_u8(ZENITH_FEEDBACK_FIELDS_NUMBER); // # of keys
buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys
write_cstr(&Bytes::from("current_timeline_size"), buf)?;
buf.put_i32(8);
buf.put_u64(self.current_timeline_size);
@@ -992,9 +992,9 @@ impl ZenithFeedback {
Ok(())
}
// Deserialize ZenithFeedback message
pub fn parse(mut buf: Bytes) -> ZenithFeedback {
let mut zf = ZenithFeedback::empty();
// Deserialize ReplicationFeedback message
pub fn parse(mut buf: Bytes) -> ReplicationFeedback {
let mut zf = ReplicationFeedback::empty();
let nfields = buf.get_u8();
let mut i = 0;
while i < nfields {
@@ -1035,14 +1035,14 @@ impl ZenithFeedback {
_ => {
let len = buf.get_i32();
warn!(
"ZenithFeedback parse. unknown key {} of len {}. Skip it.",
"ReplicationFeedback parse. unknown key {} of len {}. Skip it.",
key, len
);
buf.advance(len as usize);
}
}
}
trace!("ZenithFeedback parsed is {:?}", zf);
trace!("ReplicationFeedback parsed is {:?}", zf);
zf
}
}
@@ -1052,8 +1052,8 @@ mod tests {
use super::*;
#[test]
fn test_zenithfeedback_serialization() {
let mut zf = ZenithFeedback::empty();
fn test_replication_feedback_serialization() {
let mut zf = ReplicationFeedback::empty();
// Fill zf with some values
zf.current_timeline_size = 12345678;
// Set rounded time to be able to compare it with deserialized value,
@@ -1062,13 +1062,13 @@ mod tests {
let mut data = BytesMut::new();
zf.serialize(&mut data).unwrap();
let zf_parsed = ZenithFeedback::parse(data.freeze());
let zf_parsed = ReplicationFeedback::parse(data.freeze());
assert_eq!(zf, zf_parsed);
}
#[test]
fn test_zenithfeedback_unknown_key() {
let mut zf = ZenithFeedback::empty();
fn test_replication_feedback_unknown_key() {
let mut zf = ReplicationFeedback::empty();
// Fill zf with some values
zf.current_timeline_size = 12345678;
// Set rounded time to be able to compare it with deserialized value,
@@ -1079,7 +1079,7 @@ mod tests {
// Add an extra field to the buffer and adjust number of keys
if let Some(first) = data.first_mut() {
*first = ZENITH_FEEDBACK_FIELDS_NUMBER + 1;
*first = REPLICATION_FEEDBACK_FIELDS_NUMBER + 1;
}
write_cstr(&Bytes::from("new_field_one"), &mut data).unwrap();
@@ -1087,7 +1087,7 @@ mod tests {
data.put_u64(42);
// Parse serialized data and check that new field is not parsed
let zf_parsed = ZenithFeedback::parse(data.freeze());
let zf_parsed = ReplicationFeedback::parse(data.freeze());
assert_eq!(zf, zf_parsed);
}

View File

@@ -14,7 +14,7 @@ use safekeeper::defaults::{
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
};
use std::collections::{BTreeSet, HashMap};
use std::path::Path;
use std::path::{Path, PathBuf};
use std::process::exit;
use std::str::FromStr;
use utils::{
@@ -159,6 +159,20 @@ fn main() -> Result<()> {
.about("Create a new blank timeline")
.arg(tenant_id_arg.clone())
.arg(branch_name_arg.clone()))
.subcommand(App::new("import")
.about("Import timeline from basebackup directory")
.arg(tenant_id_arg.clone())
.arg(timeline_id_arg.clone())
.arg(Arg::new("node-name").long("node-name").takes_value(true)
.help("Name to assign to the imported timeline"))
.arg(Arg::new("base-tarfile").long("base-tarfile").takes_value(true)
.help("Basebackup tarfile to import"))
.arg(Arg::new("base-lsn").long("base-lsn").takes_value(true)
.help("Lsn the basebackup starts at"))
.arg(Arg::new("wal-tarfile").long("wal-tarfile").takes_value(true)
.help("Wal to add after base"))
.arg(Arg::new("end-lsn").long("end-lsn").takes_value(true)
.help("Lsn the basebackup ends at")))
).subcommand(
App::new("tenant")
.setting(AppSettings::ArgRequiredElseHelp)
@@ -613,6 +627,43 @@ fn handle_timeline(timeline_match: &ArgMatches, env: &mut local_env::LocalEnv) -
timeline.timeline_id, last_record_lsn, tenant_id,
);
}
Some(("import", import_match)) => {
let tenant_id = get_tenant_id(import_match, env)?;
let timeline_id = parse_timeline_id(import_match)?.expect("No timeline id provided");
let name = import_match
.value_of("node-name")
.ok_or_else(|| anyhow!("No node name provided"))?;
// Parse base inputs
let base_tarfile = import_match
.value_of("base-tarfile")
.map(|s| PathBuf::from_str(s).unwrap())
.ok_or_else(|| anyhow!("No base-tarfile provided"))?;
let base_lsn = Lsn::from_str(
import_match
.value_of("base-lsn")
.ok_or_else(|| anyhow!("No base-lsn provided"))?,
)?;
let base = (base_lsn, base_tarfile);
// Parse pg_wal inputs
let wal_tarfile = import_match
.value_of("wal-tarfile")
.map(|s| PathBuf::from_str(s).unwrap());
let end_lsn = import_match
.value_of("end-lsn")
.map(|s| Lsn::from_str(s).unwrap());
// TODO validate both or none are provided
let pg_wal = end_lsn.zip(wal_tarfile);
let mut cplane = ComputeControlPlane::load(env.clone())?;
println!("Importing timeline into pageserver ...");
pageserver.timeline_import(tenant_id, timeline_id, base, pg_wal)?;
println!("Creating node for imported timeline ...");
env.register_branch_mapping(name.to_string(), tenant_id, timeline_id)?;
cplane.new_node(tenant_id, name, timeline_id, None, None)?;
println!("Done");
}
Some(("branch", branch_match)) => {
let tenant_id = get_tenant_id(branch_match, env)?;
let new_branch_name = branch_match

View File

@@ -61,6 +61,9 @@ utils = { path = "../libs/utils" }
remote_storage = { path = "../libs/remote_storage" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }
close_fds = "0.3.2"
walkdir = "2.3.2"
parking_lot = "0.12.1"
[dev-dependencies]
hex-literal = "0.3"

View File

@@ -13,6 +13,7 @@
use anyhow::{anyhow, bail, ensure, Context, Result};
use bytes::{BufMut, BytesMut};
use fail::fail_point;
use itertools::Itertools;
use std::fmt::Write as FmtWrite;
use std::io;
use std::io::Write;
@@ -21,7 +22,7 @@ use std::time::SystemTime;
use tar::{Builder, EntryType, Header};
use tracing::*;
use crate::reltag::SlruKind;
use crate::reltag::{RelTag, SlruKind};
use crate::repository::Timeline;
use crate::DatadirTimelineImpl;
use postgres_ffi::xlog_utils::*;
@@ -39,11 +40,12 @@ where
timeline: &'a Arc<DatadirTimelineImpl>,
pub lsn: Lsn,
prev_record_lsn: Lsn,
full_backup: bool,
finished: bool,
}
// Create basebackup with non-rel data in it. Omit relational data.
// Create basebackup with non-rel data in it.
// Only include relational data if 'full_backup' is true.
//
// Currently we use empty lsn in two cases:
// * During the basebackup right after timeline creation
@@ -58,6 +60,7 @@ where
write: W,
timeline: &'a Arc<DatadirTimelineImpl>,
req_lsn: Option<Lsn>,
full_backup: bool,
) -> Result<Basebackup<'a, W>> {
// Compute postgres doesn't have any previous WAL files, but the first
// record that it's going to write needs to include the LSN of the
@@ -94,8 +97,8 @@ where
};
info!(
"taking basebackup lsn={}, prev_lsn={}",
backup_lsn, backup_prev
"taking basebackup lsn={}, prev_lsn={} (full_backup={})",
backup_lsn, backup_prev, full_backup
);
Ok(Basebackup {
@@ -103,11 +106,14 @@ where
timeline,
lsn: backup_lsn,
prev_record_lsn: backup_prev,
full_backup,
finished: false,
})
}
pub fn send_tarball(mut self) -> anyhow::Result<()> {
// TODO include checksum
// Create pgdata subdirs structure
for dir in pg_constants::PGDATA_SUBDIRS.iter() {
let header = new_tar_header_dir(*dir)?;
@@ -140,6 +146,13 @@ where
// Create tablespace directories
for ((spcnode, dbnode), has_relmap_file) in self.timeline.list_dbdirs(self.lsn)? {
self.add_dbdir(spcnode, dbnode, has_relmap_file)?;
// Gather and send relational files in each database if full backup is requested.
if self.full_backup {
for rel in self.timeline.list_rels(spcnode, dbnode, self.lsn)? {
self.add_rel(rel)?;
}
}
}
for xid in self.timeline.list_twophase_files(self.lsn)? {
self.add_twophase_file(xid)?;
@@ -157,6 +170,38 @@ where
Ok(())
}
fn add_rel(&mut self, tag: RelTag) -> anyhow::Result<()> {
let nblocks = self.timeline.get_rel_size(tag, self.lsn)?;
// Function that adds relation segment data to archive
let mut add_file = |segment_index, data: &Vec<u8>| -> anyhow::Result<()> {
let file_name = tag.to_segfile_name(segment_index as u32);
let header = new_tar_header(&file_name, data.len() as u64)?;
self.ar.append(&header, data.as_slice())?;
Ok(())
};
// If the relation is empty, create an empty file
if nblocks == 0 {
add_file(0, &vec![])?;
return Ok(());
}
// Add a file for each chunk of blocks (aka segment)
let chunks = (0..nblocks).chunks(pg_constants::RELSEG_SIZE as usize);
for (seg, blocks) in chunks.into_iter().enumerate() {
let mut segment_data: Vec<u8> = vec![];
for blknum in blocks {
let img = self.timeline.get_rel_page_at_lsn(tag, blknum, self.lsn)?;
segment_data.extend_from_slice(&img[..]);
}
add_file(seg, &segment_data)?;
}
Ok(())
}
//
// Generate SLRU segment files from repository.
//

View File

@@ -2,7 +2,6 @@
//! Import data and WAL from a PostgreSQL data directory and WAL segments into
//! a zenith Timeline.
//!
use std::fs;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::path::{Path, PathBuf};
@@ -10,16 +9,18 @@ use std::path::{Path, PathBuf};
use anyhow::{bail, ensure, Context, Result};
use bytes::Bytes;
use tracing::*;
use walkdir::WalkDir;
use crate::pgdatadir_mapping::*;
use crate::reltag::{RelTag, SlruKind};
use crate::repository::Repository;
use crate::repository::Timeline;
use crate::walingest::WalIngest;
use postgres_ffi::relfile_utils::*;
use postgres_ffi::waldecoder::*;
use postgres_ffi::xlog_utils::*;
use postgres_ffi::Oid;
use postgres_ffi::{pg_constants, ControlFileData, DBState_DB_SHUTDOWNED};
use postgres_ffi::{Oid, TransactionId};
use utils::lsn::Lsn;
///
@@ -35,100 +36,29 @@ pub fn import_timeline_from_postgres_datadir<R: Repository>(
) -> Result<()> {
let mut pg_control: Option<ControlFileData> = None;
// TODO this shoud be start_lsn, which is not necessarily equal to end_lsn (aka lsn)
// Then fishing out pg_control would be unnecessary
let mut modification = tline.begin_modification(lsn);
modification.init_empty()?;
// Scan 'global'
let mut relfiles: Vec<PathBuf> = Vec::new();
for direntry in fs::read_dir(path.join("global"))? {
let direntry = direntry?;
match direntry.file_name().to_str() {
None => continue,
// Import all but pg_wal
let all_but_wal = WalkDir::new(path)
.into_iter()
.filter_entry(|entry| !entry.path().ends_with("pg_wal"));
for entry in all_but_wal {
let entry = entry?;
let metadata = entry.metadata().expect("error getting dir entry metadata");
if metadata.is_file() {
let absolute_path = entry.path();
let relative_path = absolute_path.strip_prefix(path)?;
Some("pg_control") => {
pg_control = Some(import_control_file(&mut modification, &direntry.path())?);
}
Some("pg_filenode.map") => {
import_relmap_file(
&mut modification,
pg_constants::GLOBALTABLESPACE_OID,
0,
&direntry.path(),
)?;
}
// Load any relation files into the page server (but only after the other files)
_ => relfiles.push(direntry.path()),
}
}
for relfile in relfiles {
import_relfile(
&mut modification,
&relfile,
pg_constants::GLOBALTABLESPACE_OID,
0,
)?;
}
// Scan 'base'. It contains database dirs, the database OID is the filename.
// E.g. 'base/12345', where 12345 is the database OID.
for direntry in fs::read_dir(path.join("base"))? {
let direntry = direntry?;
//skip all temporary files
if direntry.file_name().to_string_lossy() == "pgsql_tmp" {
continue;
}
let dboid = direntry.file_name().to_string_lossy().parse::<u32>()?;
let mut relfiles: Vec<PathBuf> = Vec::new();
for direntry in fs::read_dir(direntry.path())? {
let direntry = direntry?;
match direntry.file_name().to_str() {
None => continue,
Some("PG_VERSION") => {
//modification.put_dbdir_creation(pg_constants::DEFAULTTABLESPACE_OID, dboid)?;
}
Some("pg_filenode.map") => import_relmap_file(
&mut modification,
pg_constants::DEFAULTTABLESPACE_OID,
dboid,
&direntry.path(),
)?,
// Load any relation files into the page server
_ => relfiles.push(direntry.path()),
let file = File::open(absolute_path)?;
let len = metadata.len() as usize;
if let Some(control_file) = import_file(&mut modification, relative_path, file, len)? {
pg_control = Some(control_file);
}
}
for relfile in relfiles {
import_relfile(
&mut modification,
&relfile,
pg_constants::DEFAULTTABLESPACE_OID,
dboid,
)?;
}
}
for entry in fs::read_dir(path.join("pg_xact"))? {
let entry = entry?;
import_slru_file(&mut modification, SlruKind::Clog, &entry.path())?;
}
for entry in fs::read_dir(path.join("pg_multixact").join("members"))? {
let entry = entry?;
import_slru_file(&mut modification, SlruKind::MultiXactMembers, &entry.path())?;
}
for entry in fs::read_dir(path.join("pg_multixact").join("offsets"))? {
let entry = entry?;
import_slru_file(&mut modification, SlruKind::MultiXactOffsets, &entry.path())?;
}
for entry in fs::read_dir(path.join("pg_twophase"))? {
let entry = entry?;
let xid = u32::from_str_radix(&entry.path().to_string_lossy(), 16)?;
import_twophase_file(&mut modification, xid, &entry.path())?;
}
// TODO: Scan pg_tblspc
// We're done importing all the data files.
modification.commit()?;
@@ -158,31 +88,30 @@ pub fn import_timeline_from_postgres_datadir<R: Repository>(
}
// subroutine of import_timeline_from_postgres_datadir(), to load one relation file.
fn import_relfile<R: Repository>(
fn import_rel<R: Repository, Reader: Read>(
modification: &mut DatadirModification<R>,
path: &Path,
spcoid: Oid,
dboid: Oid,
mut reader: Reader,
len: usize,
) -> anyhow::Result<()> {
// Does it look like a relation file?
trace!("importing rel file {}", path.display());
let (relnode, forknum, segno) = parse_relfilename(&path.file_name().unwrap().to_string_lossy())
.map_err(|e| {
warn!("unrecognized file in postgres datadir: {:?} ({})", path, e);
e
})?;
let filename = &path
.file_name()
.expect("missing rel filename")
.to_string_lossy();
let (relnode, forknum, segno) = parse_relfilename(filename).map_err(|e| {
warn!("unrecognized file in postgres datadir: {:?} ({})", path, e);
e
})?;
let mut file = File::open(path)?;
let mut buf: [u8; 8192] = [0u8; 8192];
let len = file.metadata().unwrap().len();
ensure!(len % pg_constants::BLCKSZ as u64 == 0);
let nblocks = len / pg_constants::BLCKSZ as u64;
if segno != 0 {
todo!();
}
ensure!(len % pg_constants::BLCKSZ as usize == 0);
let nblocks = len / pg_constants::BLCKSZ as usize;
let rel = RelTag {
spcnode: spcoid,
@@ -190,11 +119,22 @@ fn import_relfile<R: Repository>(
relnode,
forknum,
};
modification.put_rel_creation(rel, nblocks as u32)?;
let mut blknum: u32 = segno * (1024 * 1024 * 1024 / pg_constants::BLCKSZ as u32);
// Call put_rel_creation for every segment of the relation,
// because there is no guarantee about the order in which we are processing segments.
// ignore "relation already exists" error
if let Err(e) = modification.put_rel_creation(rel, nblocks as u32) {
if e.to_string().contains("already exists") {
debug!("relation {} already exists. we must be extending it", rel);
} else {
return Err(e);
}
}
loop {
let r = file.read_exact(&mut buf);
let r = reader.read_exact(&mut buf);
match r {
Ok(_) => {
modification.put_rel_page_image(rel, blknum, Bytes::copy_from_slice(&buf))?;
@@ -204,7 +144,9 @@ fn import_relfile<R: Repository>(
Err(err) => match err.kind() {
std::io::ErrorKind::UnexpectedEof => {
// reached EOF. That's expected.
ensure!(blknum == nblocks as u32, "unexpected EOF");
let relative_blknum =
blknum - segno * (1024 * 1024 * 1024 / pg_constants::BLCKSZ as u32);
ensure!(relative_blknum == nblocks as u32, "unexpected EOF");
break;
}
_ => {
@@ -215,96 +157,43 @@ fn import_relfile<R: Repository>(
blknum += 1;
}
// Update relation size
//
// If we process rel segments out of order,
// put_rel_extend will skip the update.
modification.put_rel_extend(rel, blknum)?;
Ok(())
}
/// Import a relmapper (pg_filenode.map) file into the repository
fn import_relmap_file<R: Repository>(
modification: &mut DatadirModification<R>,
spcnode: Oid,
dbnode: Oid,
path: &Path,
) -> Result<()> {
let mut file = File::open(path)?;
let mut buffer = Vec::new();
// read the whole file
file.read_to_end(&mut buffer)?;
trace!("importing relmap file {}", path.display());
modification.put_relmap_file(spcnode, dbnode, Bytes::copy_from_slice(&buffer[..]))?;
Ok(())
}
/// Import a twophase state file (pg_twophase/<xid>) into the repository
fn import_twophase_file<R: Repository>(
modification: &mut DatadirModification<R>,
xid: TransactionId,
path: &Path,
) -> Result<()> {
let mut file = File::open(path)?;
let mut buffer = Vec::new();
// read the whole file
file.read_to_end(&mut buffer)?;
trace!("importing non-rel file {}", path.display());
modification.put_twophase_file(xid, Bytes::copy_from_slice(&buffer[..]))?;
Ok(())
}
///
/// Import pg_control file into the repository.
///
/// The control file is imported as is, but we also extract the checkpoint record
/// from it and store it separated.
fn import_control_file<R: Repository>(
modification: &mut DatadirModification<R>,
path: &Path,
) -> Result<ControlFileData> {
let mut file = File::open(path)?;
let mut buffer = Vec::new();
// read the whole file
file.read_to_end(&mut buffer)?;
trace!("importing control file {}", path.display());
// Import it as ControlFile
modification.put_control_file(Bytes::copy_from_slice(&buffer[..]))?;
// Extract the checkpoint record and import it separately.
let pg_control = ControlFileData::decode(&buffer)?;
let checkpoint_bytes = pg_control.checkPointCopy.encode()?;
modification.put_checkpoint(checkpoint_bytes)?;
Ok(pg_control)
}
///
/// Import an SLRU segment file
///
fn import_slru_file<R: Repository>(
fn import_slru<R: Repository, Reader: Read>(
modification: &mut DatadirModification<R>,
slru: SlruKind,
path: &Path,
mut reader: Reader,
len: usize,
) -> Result<()> {
trace!("importing slru file {}", path.display());
let mut file = File::open(path)?;
let mut buf: [u8; 8192] = [0u8; 8192];
let segno = u32::from_str_radix(&path.file_name().unwrap().to_string_lossy(), 16)?;
let filename = &path
.file_name()
.expect("missing slru filename")
.to_string_lossy();
let segno = u32::from_str_radix(filename, 16)?;
let len = file.metadata().unwrap().len();
ensure!(len % pg_constants::BLCKSZ as u64 == 0); // we assume SLRU block size is the same as BLCKSZ
let nblocks = len / pg_constants::BLCKSZ as u64;
ensure!(len % pg_constants::BLCKSZ as usize == 0); // we assume SLRU block size is the same as BLCKSZ
let nblocks = len / pg_constants::BLCKSZ as usize;
ensure!(nblocks <= pg_constants::SLRU_PAGES_PER_SEGMENT as u64);
ensure!(nblocks <= pg_constants::SLRU_PAGES_PER_SEGMENT as usize);
modification.put_slru_segment_creation(slru, segno, nblocks as u32)?;
let mut rpageno = 0;
loop {
let r = file.read_exact(&mut buf);
let r = reader.read_exact(&mut buf);
match r {
Ok(_) => {
modification.put_slru_page_image(
@@ -396,10 +285,258 @@ fn import_wal<R: Repository>(
}
if last_lsn != startpoint {
debug!("reached end of WAL at {}", last_lsn);
info!("reached end of WAL at {}", last_lsn);
} else {
info!("no WAL to import at {}", last_lsn);
}
Ok(())
}
pub fn import_basebackup_from_tar<R: Repository, Reader: Read>(
tline: &mut DatadirTimeline<R>,
reader: Reader,
base_lsn: Lsn,
) -> Result<()> {
info!("importing base at {}", base_lsn);
let mut modification = tline.begin_modification(base_lsn);
modification.init_empty()?;
let mut pg_control: Option<ControlFileData> = None;
// Import base
for base_tar_entry in tar::Archive::new(reader).entries()? {
let entry = base_tar_entry?;
let header = entry.header();
let len = header.entry_size()? as usize;
let file_path = header.path()?.into_owned();
match header.entry_type() {
tar::EntryType::Regular => {
if let Some(res) = import_file(&mut modification, file_path.as_ref(), entry, len)? {
// We found the pg_control file.
pg_control = Some(res);
}
}
tar::EntryType::Directory => {
debug!("directory {:?}", file_path);
}
_ => {
panic!("tar::EntryType::?? {}", file_path.display());
}
}
}
// sanity check: ensure that pg_control is loaded
let _pg_control = pg_control.context("pg_control file not found")?;
modification.commit()?;
Ok(())
}
pub fn import_wal_from_tar<R: Repository, Reader: Read>(
tline: &mut DatadirTimeline<R>,
reader: Reader,
start_lsn: Lsn,
end_lsn: Lsn,
) -> Result<()> {
// Set up walingest mutable state
let mut waldecoder = WalStreamDecoder::new(start_lsn);
let mut segno = start_lsn.segment_number(pg_constants::WAL_SEGMENT_SIZE);
let mut offset = start_lsn.segment_offset(pg_constants::WAL_SEGMENT_SIZE);
let mut last_lsn = start_lsn;
let mut walingest = WalIngest::new(tline, start_lsn)?;
// Ingest wal until end_lsn
info!("importing wal until {}", end_lsn);
let mut pg_wal_tar = tar::Archive::new(reader);
let mut pg_wal_entries_iter = pg_wal_tar.entries()?;
while last_lsn <= end_lsn {
let bytes = {
let entry = pg_wal_entries_iter.next().expect("expected more wal")?;
let header = entry.header();
let file_path = header.path()?.into_owned();
match header.entry_type() {
tar::EntryType::Regular => {
// FIXME: assume postgresql tli 1 for now
let expected_filename = XLogFileName(1, segno, pg_constants::WAL_SEGMENT_SIZE);
let file_name = file_path
.file_name()
.expect("missing wal filename")
.to_string_lossy();
ensure!(expected_filename == file_name);
debug!("processing wal file {:?}", file_path);
read_all_bytes(entry)?
}
tar::EntryType::Directory => {
debug!("directory {:?}", file_path);
continue;
}
_ => {
panic!("tar::EntryType::?? {}", file_path.display());
}
}
};
waldecoder.feed_bytes(&bytes[offset..]);
while last_lsn <= end_lsn {
if let Some((lsn, recdata)) = waldecoder.poll_decode()? {
walingest.ingest_record(tline, recdata, lsn)?;
last_lsn = lsn;
debug!("imported record at {} (end {})", lsn, end_lsn);
}
}
debug!("imported records up to {}", last_lsn);
segno += 1;
offset = 0;
}
if last_lsn != start_lsn {
info!("reached end of WAL at {}", last_lsn);
} else {
info!("there was no WAL to import at {}", last_lsn);
}
// Log any extra unused files
for e in &mut pg_wal_entries_iter {
let entry = e?;
let header = entry.header();
let file_path = header.path()?.into_owned();
info!("skipping {:?}", file_path);
}
Ok(())
}
pub fn import_file<R: Repository, Reader: Read>(
modification: &mut DatadirModification<R>,
file_path: &Path,
reader: Reader,
len: usize,
) -> Result<Option<ControlFileData>> {
debug!("looking at {:?}", file_path);
if file_path.starts_with("global") {
let spcnode = pg_constants::GLOBALTABLESPACE_OID;
let dbnode = 0;
match file_path
.file_name()
.expect("missing filename")
.to_string_lossy()
.as_ref()
{
"pg_control" => {
let bytes = read_all_bytes(reader)?;
// Extract the checkpoint record and import it separately.
let pg_control = ControlFileData::decode(&bytes[..])?;
let checkpoint_bytes = pg_control.checkPointCopy.encode()?;
modification.put_checkpoint(checkpoint_bytes)?;
debug!("imported control file");
// Import it as ControlFile
modification.put_control_file(bytes)?;
return Ok(Some(pg_control));
}
"pg_filenode.map" => {
let bytes = read_all_bytes(reader)?;
modification.put_relmap_file(spcnode, dbnode, bytes)?;
debug!("imported relmap file")
}
"PG_VERSION" => {
debug!("ignored");
}
_ => {
import_rel(modification, file_path, spcnode, dbnode, reader, len)?;
debug!("imported rel creation");
}
}
} else if file_path.starts_with("base") {
let spcnode = pg_constants::DEFAULTTABLESPACE_OID;
let dbnode: u32 = file_path
.iter()
.nth(1)
.expect("invalid file path, expected dbnode")
.to_string_lossy()
.parse()?;
match file_path
.file_name()
.expect("missing base filename")
.to_string_lossy()
.as_ref()
{
"pg_filenode.map" => {
let bytes = read_all_bytes(reader)?;
modification.put_relmap_file(spcnode, dbnode, bytes)?;
debug!("imported relmap file")
}
"PG_VERSION" => {
debug!("ignored");
}
_ => {
import_rel(modification, file_path, spcnode, dbnode, reader, len)?;
debug!("imported rel creation");
}
}
} else if file_path.starts_with("pg_xact") {
let slru = SlruKind::Clog;
import_slru(modification, slru, file_path, reader, len)?;
debug!("imported clog slru");
} else if file_path.starts_with("pg_multixact/offsets") {
let slru = SlruKind::MultiXactOffsets;
import_slru(modification, slru, file_path, reader, len)?;
debug!("imported multixact offsets slru");
} else if file_path.starts_with("pg_multixact/members") {
let slru = SlruKind::MultiXactMembers;
import_slru(modification, slru, file_path, reader, len)?;
debug!("imported multixact members slru");
} else if file_path.starts_with("pg_twophase") {
let file_name = &file_path
.file_name()
.expect("missing twophase filename")
.to_string_lossy();
let xid = u32::from_str_radix(file_name, 16)?;
let bytes = read_all_bytes(reader)?;
modification.put_twophase_file(xid, Bytes::copy_from_slice(&bytes[..]))?;
debug!("imported twophase file");
} else if file_path.starts_with("pg_wal") {
debug!("found wal file in base section. ignore it");
} else if file_path.starts_with("zenith.signal") {
// Parse zenith signal file to set correct previous LSN
let bytes = read_all_bytes(reader)?;
// zenith.signal format is "PREV LSN: prev_lsn"
let zenith_signal = std::str::from_utf8(&bytes)?;
let zenith_signal = zenith_signal.split(':').collect::<Vec<_>>();
let prev_lsn = zenith_signal[1].trim().parse::<Lsn>()?;
let writer = modification.tline.tline.writer();
writer.finish_write(prev_lsn);
debug!("imported zenith signal {}", prev_lsn);
} else if file_path.starts_with("pg_tblspc") {
// TODO Backups exported from neon won't have pg_tblspc, but we will need
// this to import arbitrary postgres databases.
bail!("Importing pg_tblspc is not implemented");
} else {
debug!("ignored");
}
Ok(None)
}
fn read_all_bytes<Reader: Read>(mut reader: Reader) -> Result<Bytes> {
let mut buf: Vec<u8> = vec![];
reader.read_to_end(&mut buf)?;
Ok(Bytes::copy_from_slice(&buf[..]))
}

View File

@@ -243,15 +243,15 @@ impl Repository for LayeredRepository {
);
timeline.layers.write().unwrap().next_open_layer_at = Some(initdb_lsn);
// Insert if not exists
let timeline = Arc::new(timeline);
let r = timelines.insert(
timelineid,
LayeredTimelineEntry::Loaded(Arc::clone(&timeline)),
);
ensure!(
r.is_none(),
"assertion failure, inserted duplicate timeline"
);
match timelines.entry(timelineid) {
Entry::Occupied(_) => bail!("Timeline already exists"),
Entry::Vacant(vacant) => {
vacant.insert(LayeredTimelineEntry::Loaded(Arc::clone(&timeline)))
}
};
Ok(timeline)
}

View File

@@ -36,13 +36,12 @@
//! mapping is automatically removed and the slot is marked free.
//!
use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::{
collections::{hash_map::Entry, HashMap},
convert::TryInto,
sync::{
atomic::{AtomicU8, AtomicUsize, Ordering},
RwLock, RwLockReadGuard, RwLockWriteGuard, TryLockError,
},
sync::atomic::{AtomicU8, AtomicUsize, Ordering},
};
use once_cell::sync::OnceCell;
@@ -385,7 +384,7 @@ impl PageCache {
for slot_idx in 0..self.slots.len() {
let slot = &self.slots[slot_idx];
let mut inner = slot.inner.write().unwrap();
let mut inner = slot.inner.write();
if let Some(key) = &inner.key {
match key {
CacheKey::EphemeralPage { file_id, blkno: _ } if *file_id == drop_file_id => {
@@ -413,7 +412,7 @@ impl PageCache {
for slot_idx in 0..self.slots.len() {
let slot = &self.slots[slot_idx];
let mut inner = slot.inner.write().unwrap();
let mut inner = slot.inner.write();
if let Some(key) = &inner.key {
match key {
CacheKey::ImmutableFilePage { file_id, blkno: _ }
@@ -454,7 +453,7 @@ impl PageCache {
// that it's still what we expected (because we released the mapping
// lock already, another thread could have evicted the page)
let slot = &self.slots[slot_idx];
let inner = slot.inner.read().unwrap();
let inner = slot.inner.read();
if inner.key.as_ref() == Some(cache_key) {
slot.inc_usage_count();
return Some(PageReadGuard(inner));
@@ -543,7 +542,7 @@ impl PageCache {
// that it's still what we expected (because we don't released the mapping
// lock already, another thread could have evicted the page)
let slot = &self.slots[slot_idx];
let inner = slot.inner.write().unwrap();
let inner = slot.inner.write();
if inner.key.as_ref() == Some(cache_key) {
slot.inc_usage_count();
return Some(PageWriteGuard { inner, valid: true });
@@ -611,7 +610,7 @@ impl PageCache {
fn search_mapping(&self, cache_key: &mut CacheKey) -> Option<usize> {
match cache_key {
CacheKey::MaterializedPage { hash_key, lsn } => {
let map = self.materialized_page_map.read().unwrap();
let map = self.materialized_page_map.read();
let versions = map.get(hash_key)?;
let version_idx = match versions.binary_search_by_key(lsn, |v| v.lsn) {
@@ -624,11 +623,11 @@ impl PageCache {
Some(version.slot_idx)
}
CacheKey::EphemeralPage { file_id, blkno } => {
let map = self.ephemeral_page_map.read().unwrap();
let map = self.ephemeral_page_map.read();
Some(*map.get(&(*file_id, *blkno))?)
}
CacheKey::ImmutableFilePage { file_id, blkno } => {
let map = self.immutable_page_map.read().unwrap();
let map = self.immutable_page_map.read();
Some(*map.get(&(*file_id, *blkno))?)
}
}
@@ -641,7 +640,7 @@ impl PageCache {
fn search_mapping_for_write(&self, key: &CacheKey) -> Option<usize> {
match key {
CacheKey::MaterializedPage { hash_key, lsn } => {
let map = self.materialized_page_map.read().unwrap();
let map = self.materialized_page_map.read();
let versions = map.get(hash_key)?;
if let Ok(version_idx) = versions.binary_search_by_key(lsn, |v| v.lsn) {
@@ -651,11 +650,11 @@ impl PageCache {
}
}
CacheKey::EphemeralPage { file_id, blkno } => {
let map = self.ephemeral_page_map.read().unwrap();
let map = self.ephemeral_page_map.read();
Some(*map.get(&(*file_id, *blkno))?)
}
CacheKey::ImmutableFilePage { file_id, blkno } => {
let map = self.immutable_page_map.read().unwrap();
let map = self.immutable_page_map.read();
Some(*map.get(&(*file_id, *blkno))?)
}
}
@@ -670,7 +669,7 @@ impl PageCache {
hash_key: old_hash_key,
lsn: old_lsn,
} => {
let mut map = self.materialized_page_map.write().unwrap();
let mut map = self.materialized_page_map.write();
if let Entry::Occupied(mut old_entry) = map.entry(old_hash_key.clone()) {
let versions = old_entry.get_mut();
@@ -685,12 +684,12 @@ impl PageCache {
}
}
CacheKey::EphemeralPage { file_id, blkno } => {
let mut map = self.ephemeral_page_map.write().unwrap();
let mut map = self.ephemeral_page_map.write();
map.remove(&(*file_id, *blkno))
.expect("could not find old key in mapping");
}
CacheKey::ImmutableFilePage { file_id, blkno } => {
let mut map = self.immutable_page_map.write().unwrap();
let mut map = self.immutable_page_map.write();
map.remove(&(*file_id, *blkno))
.expect("could not find old key in mapping");
}
@@ -708,7 +707,7 @@ impl PageCache {
hash_key: new_key,
lsn: new_lsn,
} => {
let mut map = self.materialized_page_map.write().unwrap();
let mut map = self.materialized_page_map.write();
let versions = map.entry(new_key.clone()).or_default();
match versions.binary_search_by_key(new_lsn, |v| v.lsn) {
Ok(version_idx) => Some(versions[version_idx].slot_idx),
@@ -725,7 +724,7 @@ impl PageCache {
}
}
CacheKey::EphemeralPage { file_id, blkno } => {
let mut map = self.ephemeral_page_map.write().unwrap();
let mut map = self.ephemeral_page_map.write();
match map.entry((*file_id, *blkno)) {
Entry::Occupied(entry) => Some(*entry.get()),
Entry::Vacant(entry) => {
@@ -735,7 +734,7 @@ impl PageCache {
}
}
CacheKey::ImmutableFilePage { file_id, blkno } => {
let mut map = self.immutable_page_map.write().unwrap();
let mut map = self.immutable_page_map.write();
match map.entry((*file_id, *blkno)) {
Entry::Occupied(entry) => Some(*entry.get()),
Entry::Vacant(entry) => {
@@ -765,11 +764,8 @@ impl PageCache {
if slot.dec_usage_count() == 0 {
let mut inner = match slot.inner.try_write() {
Ok(inner) => inner,
Err(TryLockError::Poisoned(err)) => {
panic!("buffer lock was poisoned: {:?}", err)
}
Err(TryLockError::WouldBlock) => {
Some(inner) => inner,
None => {
// If we have looped through the whole buffer pool 10 times
// and still haven't found a victim buffer, something's wrong.
// Maybe all the buffers were in locked. That could happen in

View File

@@ -13,7 +13,7 @@ use anyhow::{bail, ensure, Context, Result};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use lazy_static::lazy_static;
use regex::Regex;
use std::io;
use std::io::{self, Read};
use std::net::TcpListener;
use std::str;
use std::str::FromStr;
@@ -29,6 +29,8 @@ use utils::{
use crate::basebackup;
use crate::config::{PageServerConf, ProfilingConfig};
use crate::import_datadir::{import_basebackup_from_tar, import_wal_from_tar};
use crate::layered_repository::LayeredRepository;
use crate::pgdatadir_mapping::{DatadirTimeline, LsnForTimestamp};
use crate::profiling::profpoint_start;
use crate::reltag::RelTag;
@@ -200,6 +202,96 @@ impl PagestreamBeMessage {
}
}
/// Implements Read for the server side of CopyIn
struct CopyInReader<'a> {
pgb: &'a mut PostgresBackend,
/// Overflow buffer for bytes sent in CopyData messages
/// that the reader (caller of read) hasn't asked for yet.
/// TODO use BytesMut?
buf: Vec<u8>,
/// Bytes before `buf_begin` are considered as dropped.
/// This allows us to implement O(1) pop_front on Vec<u8>.
/// The Vec won't grow large because we only add to it
/// when it's empty.
buf_begin: usize,
}
impl<'a> CopyInReader<'a> {
// NOTE: pgb should be in copy in state already
fn new(pgb: &'a mut PostgresBackend) -> Self {
Self {
pgb,
buf: Vec::<_>::new(),
buf_begin: 0,
}
}
}
impl<'a> Drop for CopyInReader<'a> {
fn drop(&mut self) {
// Finalize copy protocol so that self.pgb can be reused
// TODO instead, maybe take ownership of pgb and give it back at the end
let mut buf: Vec<u8> = vec![];
let _ = self.read_to_end(&mut buf);
}
}
impl<'a> Read for CopyInReader<'a> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
while !thread_mgr::is_shutdown_requested() {
// Return from buffer if nonempty
if self.buf_begin < self.buf.len() {
let bytes_to_read = std::cmp::min(buf.len(), self.buf.len() - self.buf_begin);
buf[..bytes_to_read].copy_from_slice(&self.buf[self.buf_begin..][..bytes_to_read]);
self.buf_begin += bytes_to_read;
return Ok(bytes_to_read);
}
// Delete garbage
self.buf.clear();
self.buf_begin = 0;
// Wait for client to send CopyData bytes
match self.pgb.read_message() {
Ok(Some(message)) => {
let copy_data_bytes = match message {
FeMessage::CopyData(bytes) => bytes,
FeMessage::CopyDone => return Ok(0),
FeMessage::Sync => continue,
m => {
let msg = format!("unexpected message {:?}", m);
self.pgb.write_message(&BeMessage::ErrorResponse(&msg))?;
return Err(io::Error::new(io::ErrorKind::Other, msg));
}
};
// Return as much as we can, saving the rest in self.buf
let mut reader = copy_data_bytes.reader();
let bytes_read = reader.read(buf)?;
reader.read_to_end(&mut self.buf)?;
return Ok(bytes_read);
}
Ok(None) => {
let msg = "client closed connection";
self.pgb.write_message(&BeMessage::ErrorResponse(msg))?;
return Err(io::Error::new(io::ErrorKind::Other, msg));
}
Err(e) => {
if !is_socket_read_timed_out(&e) {
return Err(io::Error::new(io::ErrorKind::Other, e));
}
}
}
}
// Shutting down
let msg = "Importer thread was shut down";
Err(io::Error::new(io::ErrorKind::Other, msg))
}
}
///////////////////////////////////////////////////////////////////////////////
///
@@ -447,6 +539,98 @@ impl PageServerHandler {
Ok(())
}
fn handle_import_basebackup(
&self,
pgb: &mut PostgresBackend,
tenant_id: ZTenantId,
timeline_id: ZTimelineId,
base_lsn: Lsn,
_end_lsn: Lsn,
) -> anyhow::Result<()> {
thread_mgr::associate_with(Some(tenant_id), Some(timeline_id));
let _enter =
info_span!("import basebackup", timeline = %timeline_id, tenant = %tenant_id).entered();
// Create empty timeline
info!("creating new timeline");
let repo = tenant_mgr::get_repository_for_tenant(tenant_id)?;
let timeline = repo.create_empty_timeline(timeline_id, Lsn(0))?;
let repartition_distance = repo.get_checkpoint_distance();
let mut datadir_timeline =
DatadirTimeline::<LayeredRepository>::new(timeline, repartition_distance);
// TODO mark timeline as not ready until it reaches end_lsn.
// We might have some wal to import as well, and we should prevent compute
// from connecting before that and writing conflicting wal.
//
// This is not relevant for pageserver->pageserver migrations, since there's
// no wal to import. But should be fixed if we want to import from postgres.
// TODO leave clean state on error. For now you can use detach to clean
// up broken state from a failed import.
// Import basebackup provided via CopyData
info!("importing basebackup");
pgb.write_message(&BeMessage::CopyInResponse)?;
let reader = CopyInReader::new(pgb);
import_basebackup_from_tar(&mut datadir_timeline, reader, base_lsn)?;
// TODO check checksum
// Meanwhile you can verify client-side by taking fullbackup
// and checking that it matches in size with what was imported.
// It wouldn't work if base came from vanilla postgres though,
// since we discard some log files.
// Flush data to disk, then upload to s3
info!("flushing layers");
datadir_timeline.tline.checkpoint(CheckpointConfig::Flush)?;
info!("done");
Ok(())
}
fn handle_import_wal(
&self,
pgb: &mut PostgresBackend,
tenant_id: ZTenantId,
timeline_id: ZTimelineId,
start_lsn: Lsn,
end_lsn: Lsn,
) -> anyhow::Result<()> {
thread_mgr::associate_with(Some(tenant_id), Some(timeline_id));
let _enter =
info_span!("import wal", timeline = %timeline_id, tenant = %tenant_id).entered();
let repo = tenant_mgr::get_repository_for_tenant(tenant_id)?;
let timeline = repo.get_timeline_load(timeline_id)?;
ensure!(timeline.get_last_record_lsn() == start_lsn);
let repartition_distance = repo.get_checkpoint_distance();
let mut datadir_timeline =
DatadirTimeline::<LayeredRepository>::new(timeline, repartition_distance);
// TODO leave clean state on error. For now you can use detach to clean
// up broken state from a failed import.
// Import wal provided via CopyData
info!("importing wal");
pgb.write_message(&BeMessage::CopyInResponse)?;
let reader = CopyInReader::new(pgb);
import_wal_from_tar(&mut datadir_timeline, reader, start_lsn, end_lsn)?;
// TODO Does it make sense to overshoot?
ensure!(datadir_timeline.tline.get_last_record_lsn() >= end_lsn);
// Flush data to disk, then upload to s3. No need for a forced checkpoint.
// We only want to persist the data, and it doesn't matter if it's in the
// shape of deltas or images.
info!("flushing layers");
datadir_timeline.tline.checkpoint(CheckpointConfig::Flush)?;
info!("done");
Ok(())
}
/// Helper function to handle the LSN from client request.
///
/// Each GetPage (and Exists and Nblocks) request includes information about
@@ -549,17 +733,10 @@ impl PageServerHandler {
let latest_gc_cutoff_lsn = timeline.tline.get_latest_gc_cutoff_lsn();
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn)?;
let all_rels = timeline.list_rels(pg_constants::DEFAULTTABLESPACE_OID, req.dbnode, lsn)?;
let mut total_blocks: i64 = 0;
let total_blocks =
timeline.get_db_size(pg_constants::DEFAULTTABLESPACE_OID, req.dbnode, lsn)?;
for rel in all_rels {
if rel.forknum == 0 {
let n_blocks = timeline.get_rel_size(rel, lsn).unwrap_or(0);
total_blocks += n_blocks as i64;
}
}
let db_size = total_blocks * pg_constants::BLCKSZ as i64;
let db_size = total_blocks as i64 * pg_constants::BLCKSZ as i64;
Ok(PagestreamBeMessage::DbSize(PagestreamDbSizeResponse {
db_size,
@@ -596,6 +773,7 @@ impl PageServerHandler {
timelineid: ZTimelineId,
lsn: Option<Lsn>,
tenantid: ZTenantId,
full_backup: bool,
) -> anyhow::Result<()> {
let span = info_span!("basebackup", timeline = %timelineid, tenant = %tenantid, lsn = field::Empty);
let _enter = span.enter();
@@ -618,7 +796,7 @@ impl PageServerHandler {
{
let mut writer = CopyDataSink { pgb };
let basebackup = basebackup::Basebackup::new(&mut writer, &timeline, lsn)?;
let basebackup = basebackup::Basebackup::new(&mut writer, &timeline, lsn, full_backup)?;
span.record("lsn", &basebackup.lsn.to_string().as_str());
basebackup.send_tarball()?;
}
@@ -721,8 +899,79 @@ impl postgres_backend::Handler for PageServerHandler {
};
// Check that the timeline exists
self.handle_basebackup_request(pgb, timelineid, lsn, tenantid)?;
self.handle_basebackup_request(pgb, timelineid, lsn, tenantid, false)?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
// same as basebackup, but result includes relational data as well
else if query_string.starts_with("fullbackup ") {
let (_, params_raw) = query_string.split_at("fullbackup ".len());
let params = params_raw.split_whitespace().collect::<Vec<_>>();
ensure!(
params.len() == 3,
"invalid param number for fullbackup command"
);
let tenantid = ZTenantId::from_str(params[0])?;
let timelineid = ZTimelineId::from_str(params[1])?;
self.check_permission(Some(tenantid))?;
// Lsn is required for fullbackup, because otherwise we would not know
// at which lsn to upload this backup.
//
// The caller is responsible for providing a valid lsn
// and using it in the subsequent import.
let lsn = Some(Lsn::from_str(params[2])?);
// Check that the timeline exists
self.handle_basebackup_request(pgb, timelineid, lsn, tenantid, true)?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("import basebackup ") {
// Import the `base` section (everything but the wal) of a basebackup.
// Assumes the tenant already exists on this pageserver.
//
// Files are scheduled to be persisted to remote storage, and the
// caller should poll the http api to check when that is done.
//
// Example import command:
// 1. Get start/end LSN from backup_manifest file
// 2. Run:
// cat my_backup/base.tar | psql -h $PAGESERVER \
// -c "import basebackup $TENANT $TIMELINE $START_LSN $END_LSN"
let (_, params_raw) = query_string.split_at("import basebackup ".len());
let params = params_raw.split_whitespace().collect::<Vec<_>>();
ensure!(params.len() == 4);
let tenant = ZTenantId::from_str(params[0])?;
let timeline = ZTimelineId::from_str(params[1])?;
let base_lsn = Lsn::from_str(params[2])?;
let end_lsn = Lsn::from_str(params[3])?;
self.check_permission(Some(tenant))?;
match self.handle_import_basebackup(pgb, tenant, timeline, base_lsn, end_lsn) {
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string()))?,
};
} else if query_string.starts_with("import wal ") {
// Import the `pg_wal` section of a basebackup.
//
// Files are scheduled to be persisted to remote storage, and the
// caller should poll the http api to check when that is done.
let (_, params_raw) = query_string.split_at("import wal ".len());
let params = params_raw.split_whitespace().collect::<Vec<_>>();
ensure!(params.len() == 4);
let tenant = ZTenantId::from_str(params[0])?;
let timeline = ZTimelineId::from_str(params[1])?;
let start_lsn = Lsn::from_str(params[2])?;
let end_lsn = Lsn::from_str(params[3])?;
self.check_permission(Some(tenant))?;
match self.handle_import_wal(pgb, tenant, timeline, start_lsn, end_lsn) {
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string()))?,
};
} else if query_string.to_ascii_lowercase().starts_with("set ") {
// important because psycopg2 executes "SET datestyle TO 'ISO'"
// on connect

View File

@@ -123,6 +123,19 @@ impl<R: Repository> DatadirTimeline<R> {
self.tline.get(key, lsn)
}
// Get size of a database in blocks
pub fn get_db_size(&self, spcnode: Oid, dbnode: Oid, lsn: Lsn) -> Result<usize> {
let mut total_blocks = 0;
let rels = self.list_rels(spcnode, dbnode, lsn)?;
for rel in rels {
let n_blocks = self.get_rel_size(rel, lsn)?;
total_blocks += n_blocks as usize;
}
Ok(total_blocks)
}
/// Get size of a relation file
pub fn get_rel_size(&self, tag: RelTag, lsn: Lsn) -> Result<BlockNumber> {
ensure!(tag.relnode != 0, "invalid relnode");
@@ -667,6 +680,10 @@ impl<'a, R: Repository> DatadirModification<'a, R> {
}
pub fn drop_dbdir(&mut self, spcnode: Oid, dbnode: Oid) -> Result<()> {
let req_lsn = self.tline.get_last_record_lsn();
let total_blocks = self.tline.get_db_size(spcnode, dbnode, req_lsn)?;
// Remove entry from dbdir
let buf = self.get(DBDIR_KEY)?;
let mut dir = DbDirectory::des(&buf)?;
@@ -680,7 +697,8 @@ impl<'a, R: Repository> DatadirModification<'a, R> {
);
}
// FIXME: update pending_nblocks
// Update logical database size.
self.pending_nblocks -= total_blocks as isize;
// Delete all relations and metadata files for the spcnode/dnode
self.delete(dbdir_key_range(spcnode, dbnode));
@@ -749,6 +767,7 @@ impl<'a, R: Repository> DatadirModification<'a, R> {
}
/// Extend relation
/// If new size is smaller, do nothing.
pub fn put_rel_extend(&mut self, rel: RelTag, nblocks: BlockNumber) -> Result<()> {
ensure!(rel.relnode != 0, "invalid relnode");
@@ -756,10 +775,13 @@ impl<'a, R: Repository> DatadirModification<'a, R> {
let size_key = rel_size_to_key(rel);
let old_size = self.get(size_key)?.get_u32_le();
let buf = nblocks.to_le_bytes();
self.put(size_key, Value::Image(Bytes::from(buf.to_vec())));
// only extend relation here. never decrease the size
if nblocks > old_size {
let buf = nblocks.to_le_bytes();
self.put(size_key, Value::Image(Bytes::from(buf.to_vec())));
self.pending_nblocks += nblocks as isize - old_size as isize;
self.pending_nblocks += nblocks as isize - old_size as isize;
}
Ok(())
}

View File

@@ -3,7 +3,7 @@ use std::cmp::Ordering;
use std::fmt;
use postgres_ffi::relfile_utils::forknumber_to_name;
use postgres_ffi::Oid;
use postgres_ffi::{pg_constants, Oid};
///
/// Relation data file segment id throughout the Postgres cluster.
@@ -75,6 +75,30 @@ impl fmt::Display for RelTag {
}
}
impl RelTag {
pub fn to_segfile_name(&self, segno: u32) -> String {
let mut name = if self.spcnode == pg_constants::GLOBALTABLESPACE_OID {
"global/".to_string()
} else {
format!("base/{}/", self.dbnode)
};
name += &self.relnode.to_string();
if let Some(fork_name) = forknumber_to_name(self.forknum) {
name += "_";
name += fork_name;
}
if segno != 0 {
name += ".";
name += &segno.to_string();
}
name
}
}
///
/// Non-relation transaction status files (clog (a.k.a. pg_xact) and
/// pg_multixact) in Postgres are handled by SLRU (Simple LRU) buffer,

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
//! Actual Postgres connection handler to stream WAL to the server.
//! Runs as a separate, cancellable Tokio task.
use std::{
str::FromStr,
sync::Arc,
@@ -10,113 +10,29 @@ use anyhow::{bail, ensure, Context};
use bytes::BytesMut;
use fail::fail_point;
use postgres::{SimpleQueryMessage, SimpleQueryRow};
use postgres_ffi::waldecoder::WalStreamDecoder;
use postgres_protocol::message::backend::ReplicationMessage;
use postgres_types::PgLsn;
use tokio::{pin, select, sync::watch, time};
use tokio_postgres::{replication::ReplicationStream, Client};
use tokio_stream::StreamExt;
use tracing::{debug, error, info, info_span, trace, warn, Instrument};
use utils::{
lsn::Lsn,
pq_proto::ZenithFeedback,
zid::{NodeId, ZTenantTimelineId},
};
use super::TaskEvent;
use crate::{
http::models::WalReceiverEntry,
repository::{Repository, Timeline},
tenant_mgr,
walingest::WalIngest,
};
use postgres_ffi::waldecoder::WalStreamDecoder;
use utils::{lsn::Lsn, pq_proto::ReplicationFeedback, zid::ZTenantTimelineId};
#[derive(Debug, Clone)]
pub enum WalConnectionEvent {
Started,
NewWal(ZenithFeedback),
End(Result<(), String>),
}
/// A wrapper around standalone Tokio task, to poll its updates or cancel the task.
#[derive(Debug)]
pub struct WalReceiverConnection {
handle: tokio::task::JoinHandle<()>,
cancellation: watch::Sender<()>,
events_receiver: watch::Receiver<WalConnectionEvent>,
}
impl WalReceiverConnection {
/// Initializes the connection task, returning a set of handles on top of it.
/// The task is started immediately after the creation, fails if no connection is established during the timeout given.
pub fn open(
id: ZTenantTimelineId,
safekeeper_id: NodeId,
wal_producer_connstr: String,
connect_timeout: Duration,
) -> Self {
let (cancellation, mut cancellation_receiver) = watch::channel(());
let (events_sender, events_receiver) = watch::channel(WalConnectionEvent::Started);
let handle = tokio::spawn(
async move {
let connection_result = handle_walreceiver_connection(
id,
&wal_producer_connstr,
&events_sender,
&mut cancellation_receiver,
connect_timeout,
)
.await
.map_err(|e| {
format!("Walreceiver connection for id {id} failed with error: {e:#}")
});
match &connection_result {
Ok(()) => {
debug!("Walreceiver connection for id {id} ended successfully")
}
Err(e) => warn!("{e}"),
}
events_sender
.send(WalConnectionEvent::End(connection_result))
.ok();
}
.instrument(info_span!("safekeeper_handle", sk = %safekeeper_id)),
);
Self {
handle,
cancellation,
events_receiver,
}
}
/// Polls for the next WAL receiver event, if there's any available since the last check.
/// Blocks if there's no new event available, returns `None` if no new events will ever occur.
/// Only the last event is returned, all events received between observatins are lost.
pub async fn next_event(&mut self) -> Option<WalConnectionEvent> {
match self.events_receiver.changed().await {
Ok(()) => Some(self.events_receiver.borrow().clone()),
Err(_cancellation_error) => None,
}
}
/// Gracefully aborts current WAL streaming task, waiting for the current WAL streamed.
pub async fn shutdown(&mut self) -> anyhow::Result<()> {
self.cancellation.send(()).ok();
let handle = &mut self.handle;
handle
.await
.context("Failed to join on a walreceiver connection task")?;
Ok(())
}
}
async fn handle_walreceiver_connection(
/// Opens a conneciton to the given wal producer and streams the WAL, sending progress messages during streaming.
pub async fn handle_walreceiver_connection(
id: ZTenantTimelineId,
wal_producer_connstr: &str,
events_sender: &watch::Sender<WalConnectionEvent>,
cancellation: &mut watch::Receiver<()>,
events_sender: &watch::Sender<TaskEvent<ReplicationFeedback>>,
mut cancellation: watch::Receiver<()>,
connect_timeout: Duration,
) -> anyhow::Result<()> {
// Connect to the database in replication mode.
@@ -214,8 +130,6 @@ async fn handle_walreceiver_connection(
while let Some(replication_message) = {
select! {
// check for shutdown first
biased;
_ = cancellation.changed() => {
info!("walreceiver interrupted");
None
@@ -328,7 +242,7 @@ async fn handle_walreceiver_connection(
// Send zenith feedback message.
// Regular standby_status_update fields are put into this message.
let zenith_status_update = ZenithFeedback {
let zenith_status_update = ReplicationFeedback {
current_timeline_size: timeline.get_current_logical_size() as u64,
ps_writelsn: write_lsn,
ps_flushlsn: flush_lsn,
@@ -344,7 +258,7 @@ async fn handle_walreceiver_connection(
.as_mut()
.zenith_status_update(data.len() as u64, &data)
.await?;
if let Err(e) = events_sender.send(WalConnectionEvent::NewWal(zenith_status_update)) {
if let Err(e) = events_sender.send(TaskEvent::NewEvent(zenith_status_update)) {
warn!("Wal connection event listener dropped, aborting the connection: {e}");
return Ok(());
}

View File

@@ -39,6 +39,8 @@ utils = { path = "../libs/utils" }
metrics = { path = "../libs/metrics" }
workspace_hack = { version = "0.1", path = "../workspace_hack" }
x509-parser = "0.13.2"
[dev-dependencies]
rcgen = "0.8.14"
rstest = "0.12"

View File

@@ -19,7 +19,7 @@ pub type Result<T> = std::result::Result<T, ConsoleAuthError>;
#[derive(Debug, Error)]
pub enum ConsoleAuthError {
#[error(transparent)]
BadProjectName(#[from] auth::credentials::ProjectNameError),
BadProjectName(#[from] auth::credentials::ClientCredsParseError),
// We shouldn't include the actual secret here.
#[error("Bad authentication secret")]
@@ -49,6 +49,12 @@ impl UserFacingError for ConsoleAuthError {
}
}
impl From<&auth::credentials::ClientCredsParseError> for ConsoleAuthError {
fn from(e: &auth::credentials::ClientCredsParseError) -> Self {
ConsoleAuthError::BadProjectName(e.clone())
}
}
// TODO: convert into an enum with "error"
#[derive(Serialize, Deserialize, Debug)]
struct GetRoleSecretResponse {
@@ -74,18 +80,12 @@ pub enum AuthInfo {
pub(super) struct Api<'a> {
endpoint: &'a ApiUrl,
creds: &'a ClientCredentials,
/// Cache project name, since we'll need it several times.
project: &'a str,
}
impl<'a> Api<'a> {
/// Construct an API object containing the auth parameters.
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result<Self> {
Ok(Self {
endpoint,
creds,
project: creds.project_name()?,
})
Ok(Self { endpoint, creds })
}
/// Authenticate the existing user or throw an error.
@@ -100,7 +100,7 @@ impl<'a> Api<'a> {
let mut url = self.endpoint.clone();
url.path_segments_mut().push("proxy_get_role_secret");
url.query_pairs_mut()
.append_pair("project", self.project)
.append_pair("project", self.creds.project_name.as_ref()?)
.append_pair("role", &self.creds.user);
// TODO: use a proper logger
@@ -123,7 +123,8 @@ impl<'a> Api<'a> {
async fn wake_compute(&self) -> Result<DatabaseInfo> {
let mut url = self.endpoint.clone();
url.path_segments_mut().push("proxy_wake_compute");
url.query_pairs_mut().append_pair("project", self.project);
let project_name = self.creds.project_name.as_ref()?;
url.query_pairs_mut().append_pair("project", project_name);
// TODO: use a proper logger
println!("cplane request: {url}");

View File

@@ -8,10 +8,32 @@ use std::collections::HashMap;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
#[derive(Debug, Error)]
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum ClientCredsParseError {
#[error("Parameter `{0}` is missing in startup packet")]
#[error("Parameter `{0}` is missing in startup packet.")]
MissingKey(&'static str),
#[error(
"Project name is not specified. \
EITHER please upgrade the postgres client library (libpq) for SNI support \
OR pass the project name as a parameter: '&options=project%3D<project-name>'."
)]
MissingSNIAndProjectName,
#[error("Inconsistent project name inferred from SNI ('{0}') and project option ('{1}').")]
InconsistentProjectNameAndSNI(String, String),
#[error("Common name is not set.")]
CommonNameNotSet,
#[error(
"SNI ('{1}') inconsistently formatted with respect to common name ('{0}'). \
SNI should be formatted as '<project-name>.<common-name>'."
)]
InconsistentCommonNameAndSNI(String, String),
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphens ('-').")]
ProjectNameContainsIllegalChars(String),
}
impl UserFacingError for ClientCredsParseError {}
@@ -22,15 +44,7 @@ impl UserFacingError for ClientCredsParseError {}
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
// New console API requires SNI info to determine the cluster name.
// Other Auth backends don't need it.
pub sni_data: Option<String>,
// project_name is passed as argument from options from url.
// In case sni_data is missing: project_name is used to determine cluster name.
// In case sni_data is available: project_name and sni_data should match (otherwise throws an error).
pub project_name: Option<String>,
pub project_name: Result<String, ClientCredsParseError>,
}
impl ClientCredentials {
@@ -38,60 +52,14 @@ impl ClientCredentials {
// This logic will likely change in the future.
self.user.ends_with("@zenith")
}
}
#[derive(Debug, Error)]
pub enum ProjectNameError {
#[error("SNI is missing. EITHER please upgrade the postgres client library OR pass the project name as a parameter: '...&options=project%3D<project-name>...'.")]
Missing,
#[error("SNI is malformed.")]
Bad,
#[error("Inconsistent project name inferred from SNI and project option. String from SNI: '{0}', String from project option: '{1}'")]
Inconsistent(String, String),
}
impl UserFacingError for ProjectNameError {}
impl ClientCredentials {
/// Determine project name from SNI or from project_name parameter from options argument.
pub fn project_name(&self) -> Result<&str, ProjectNameError> {
// Checking that if both sni_data and project_name are set, then they should match
// otherwise, throws a ProjectNameError::Inconsistent error.
if let Some(sni_data) = &self.sni_data {
let project_name_from_sni_data =
sni_data.split_once('.').ok_or(ProjectNameError::Bad)?.0;
if let Some(project_name_from_options) = &self.project_name {
if !project_name_from_options.eq(project_name_from_sni_data) {
return Err(ProjectNameError::Inconsistent(
project_name_from_sni_data.to_string(),
project_name_from_options.to_string(),
));
}
}
}
// determine the project name from self.sni_data if it exists, otherwise from self.project_name.
let ret = match &self.sni_data {
// if sni_data exists, use it to determine project name
Some(sni_data) => sni_data.split_once('.').ok_or(ProjectNameError::Bad)?.0,
// otherwise use project_option if it was manually set thought options parameter.
None => self
.project_name
.as_ref()
.ok_or(ProjectNameError::Missing)?
.as_str(),
};
Ok(ret)
}
}
impl TryFrom<HashMap<String, String>> for ClientCredentials {
type Error = ClientCredsParseError;
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
pub fn parse(
mut options: HashMap<String, String>,
sni_data: Option<&str>,
common_name: Option<&str>,
) -> Result<Self, ClientCredsParseError> {
let mut get_param = |key| {
value
options
.remove(key)
.ok_or(ClientCredsParseError::MissingKey(key))
};
@@ -99,17 +67,15 @@ impl TryFrom<HashMap<String, String>> for ClientCredentials {
let user = get_param("user")?;
let dbname = get_param("database")?;
let project_name = get_param("project").ok();
let project_name = get_project_name(sni_data, common_name, project_name.as_deref());
Ok(Self {
user,
dbname,
sni_data: None,
project_name,
})
}
}
impl ClientCredentials {
/// Use credentials to authenticate the user.
pub async fn authenticate(
self,
@@ -120,3 +86,244 @@ impl ClientCredentials {
super::backend::handle_user(config, client, self).await
}
}
/// Inferring project name from sni_data.
fn project_name_from_sni_data(
sni_data: &str,
common_name: &str,
) -> Result<String, ClientCredsParseError> {
let common_name_with_dot = format!(".{common_name}");
// check that ".{common_name_with_dot}" is the actual suffix in sni_data
if !sni_data.ends_with(&common_name_with_dot) {
return Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
common_name.to_string(),
sni_data.to_string(),
));
}
// return sni_data without the common name suffix.
Ok(sni_data
.strip_suffix(&common_name_with_dot)
.unwrap()
.to_string())
}
#[cfg(test)]
mod tests_for_project_name_from_sni_data {
use super::*;
#[test]
fn passing() {
let target_project_name = "my-project-123";
let common_name = "localtest.me";
let sni_data = format!("{target_project_name}.{common_name}");
assert_eq!(
project_name_from_sni_data(&sni_data, common_name),
Ok(target_project_name.to_string())
);
}
#[test]
fn throws_inconsistent_common_name_and_sni_data() {
let target_project_name = "my-project-123";
let common_name = "localtest.me";
let wrong_suffix = "wrongtest.me";
assert_eq!(common_name.len(), wrong_suffix.len());
let wrong_common_name = format!("wrong{wrong_suffix}");
let sni_data = format!("{target_project_name}.{wrong_common_name}");
assert_eq!(
project_name_from_sni_data(&sni_data, common_name),
Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
common_name.to_string(),
sni_data
))
);
}
}
/// Determine project name from SNI or from project_name parameter from options argument.
fn get_project_name(
sni_data: Option<&str>,
common_name: Option<&str>,
project_name: Option<&str>,
) -> Result<String, ClientCredsParseError> {
// determine the project name from sni_data if it exists, otherwise from project_name.
let ret = match sni_data {
Some(sni_data) => {
let common_name = common_name.ok_or(ClientCredsParseError::CommonNameNotSet)?;
let project_name_from_sni = project_name_from_sni_data(sni_data, common_name)?;
// check invariant: project name from options and from sni should match
if let Some(project_name) = &project_name {
if !project_name_from_sni.eq(project_name) {
return Err(ClientCredsParseError::InconsistentProjectNameAndSNI(
project_name_from_sni,
project_name.to_string(),
));
}
}
project_name_from_sni
}
None => project_name
.ok_or(ClientCredsParseError::MissingSNIAndProjectName)?
.to_string(),
};
// check formatting invariant: project name must contain only alphanumeric characters and hyphens.
if !ret.chars().all(|x: char| x.is_alphanumeric() || x == '-') {
return Err(ClientCredsParseError::ProjectNameContainsIllegalChars(ret));
}
Ok(ret)
}
#[cfg(test)]
mod tests_for_project_name_only {
use super::*;
#[test]
fn passing_from_sni_data_only() {
let target_project_name = "my-project-123";
let common_name = "localtest.me";
let sni_data = format!("{target_project_name}.{common_name}");
assert_eq!(
get_project_name(Some(&sni_data), Some(common_name), None),
Ok(target_project_name.to_string())
);
}
#[test]
fn throws_project_name_contains_illegal_chars_from_sni_data_only() {
let project_name_prefix = "my-project";
let project_name_suffix = "123";
let common_name = "localtest.me";
for illegal_char_id in 0..256 {
let illegal_char = char::from_u32(illegal_char_id).unwrap();
if !(illegal_char.is_alphanumeric() || illegal_char == '-')
&& illegal_char.to_string().len() == 1
{
let target_project_name =
format!("{project_name_prefix}{illegal_char}{project_name_suffix}");
let sni_data = format!("{target_project_name}.{common_name}");
assert_eq!(
get_project_name(Some(&sni_data), Some(common_name), None),
Err(ClientCredsParseError::ProjectNameContainsIllegalChars(
target_project_name
))
);
}
}
}
#[test]
fn passing_from_project_name_only() {
let target_project_name = "my-project-123";
let common_names = [Some("localtest.me"), None];
for common_name in common_names {
assert_eq!(
get_project_name(None, common_name, Some(target_project_name)),
Ok(target_project_name.to_string())
);
}
}
#[test]
fn throws_project_name_contains_illegal_chars_from_project_name_only() {
let project_name_prefix = "my-project";
let project_name_suffix = "123";
let common_names = [Some("localtest.me"), None];
for common_name in common_names {
for illegal_char_id in 0..256 {
let illegal_char: char = char::from_u32(illegal_char_id).unwrap();
if !(illegal_char.is_alphanumeric() || illegal_char == '-')
&& illegal_char.to_string().len() == 1
{
let target_project_name =
format!("{project_name_prefix}{illegal_char}{project_name_suffix}");
assert_eq!(
get_project_name(None, common_name, Some(&target_project_name)),
Err(ClientCredsParseError::ProjectNameContainsIllegalChars(
target_project_name
))
);
}
}
}
}
#[test]
fn passing_from_sni_data_and_project_name() {
let target_project_name = "my-project-123";
let common_name = "localtest.me";
let sni_data = format!("{target_project_name}.{common_name}");
assert_eq!(
get_project_name(
Some(&sni_data),
Some(common_name),
Some(target_project_name)
),
Ok(target_project_name.to_string())
);
}
#[test]
fn throws_inconsistent_project_name_and_sni() {
let project_name_param = "my-project-123";
let wrong_project_name = "not-my-project-123";
let common_name = "localtest.me";
let sni_data = format!("{wrong_project_name}.{common_name}");
assert_eq!(
get_project_name(Some(&sni_data), Some(common_name), Some(project_name_param)),
Err(ClientCredsParseError::InconsistentProjectNameAndSNI(
wrong_project_name.to_string(),
project_name_param.to_string()
))
);
}
#[test]
fn throws_common_name_not_set() {
let target_project_name = "my-project-123";
let wrong_project_name = "not-my-project-123";
let common_name = "localtest.me";
let sni_datas = [
Some(format!("{wrong_project_name}.{common_name}")),
Some(format!("{target_project_name}.{common_name}")),
];
let project_names = [None, Some(target_project_name)];
for sni_data in sni_datas {
for project_name_param in project_names {
assert_eq!(
get_project_name(sni_data.as_deref(), None, project_name_param),
Err(ClientCredsParseError::CommonNameNotSet)
);
}
}
}
#[test]
fn throws_inconsistent_common_name_and_sni_data() {
let target_project_name = "my-project-123";
let wrong_project_name = "not-my-project-123";
let common_name = "localtest.me";
let wrong_suffix = "wrongtest.me";
assert_eq!(common_name.len(), wrong_suffix.len());
let wrong_common_name = format!("wrong{wrong_suffix}");
let sni_datas = [
Some(format!("{wrong_project_name}.{wrong_common_name}")),
Some(format!("{target_project_name}.{wrong_common_name}")),
];
let project_names = [None, Some(target_project_name)];
for project_name_param in project_names {
for sni_data in &sni_datas {
assert_eq!(
get_project_name(sni_data.as_deref(), Some(common_name), project_name_param),
Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
common_name.to_string(),
sni_data.clone().unwrap().to_string()
))
);
}
}
}
}

View File

@@ -36,23 +36,35 @@ pub struct ProxyConfig {
pub auth_link_uri: ApiUrl,
}
pub type TlsConfig = Arc<rustls::ServerConfig>;
pub struct TlsConfig {
pub config: Arc<rustls::ServerConfig>,
pub common_name: Option<String>,
}
impl TlsConfig {
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
self.config.clone()
}
}
/// Configure TLS for the main endpoint.
pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfig> {
let key = {
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
.context("couldn't read TLS keys")?;
.context(format!("Failed to read TLS keys at '{key_path}'"))?;
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
keys.pop().map(rustls::PrivateKey).unwrap()
};
let cert_chain_bytes = std::fs::read(cert_path)
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
let cert_chain = {
let cert_chain_bytes = std::fs::read(cert_path).context("TLS cert file")?;
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
.context("couldn't read TLS certificate chain")?
.context(format!(
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
))?
.into_iter()
.map(rustls::Certificate)
.collect()
@@ -64,7 +76,25 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfi
// allow TLS 1.2 to be compatible with older client libraries
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
.with_no_client_auth()
.with_single_cert(cert_chain, key)?;
.with_single_cert(cert_chain, key)?
.into();
Ok(config.into())
// determine common name from tls-cert (-c server.crt param).
// used in asserting project name formatting invariant.
let common_name = {
let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
.context(format!(
"Failed to parse PEM object from bytes from file at '{cert_path}'."
))?
.1;
let almost_common_name = pem.parse_x509()?.tbs_certificate.subject.to_string();
let expected_prefix = "CN=*.";
let common_name = almost_common_name.strip_prefix(expected_prefix);
common_name.map(str::to_string)
};
Ok(TlsConfig {
config,
common_name,
})
}

View File

@@ -81,7 +81,7 @@ async fn handle_client(
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
}
let tls = config.tls_config.clone();
let tls = config.tls_config.as_ref();
let (stream, creds) = match handshake(stream, tls, cancel_map).await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
@@ -99,12 +99,14 @@ async fn handle_client(
/// we also take an extra care of propagating only the select handshake errors to client.
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<TlsConfig>,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, auth::ClientCredentials)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let common_name = tls.and_then(|cfg| cfg.common_name.as_deref());
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
@@ -122,7 +124,9 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
stream = PqStream::new(stream.into_inner().upgrade(tls).await?);
stream = PqStream::new(
stream.into_inner().upgrade(tls.to_server_config()).await?,
);
}
}
_ => bail!(ERR_PROTO_VIOLATION),
@@ -143,15 +147,16 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}
// Here and forth: `or_else` demands that we use a future here
let mut creds: auth::ClientCredentials = async { params.try_into() }
.or_else(|e| stream.throw_error(e))
.await?;
// Get SNI info when available
let sni_data = match stream.get_ref() {
Stream::Tls { tls } => tls.get_ref().1.sni_hostname().map(|s| s.to_owned()),
_ => None,
};
// Set SNI info when available
if let Stream::Tls { tls } = stream.get_ref() {
creds.sni_data = tls.get_ref().1.sni_hostname().map(|s| s.to_owned());
}
// Construct credentials
let creds =
auth::ClientCredentials::parse(params, sni_data.as_deref(), common_name);
let creds = async { creds }.or_else(|e| stream.throw_error(e)).await?;
break Ok(Some((stream, creds)));
}
@@ -264,12 +269,13 @@ mod tests {
}
/// Generate TLS certificates and build rustls configs for client and server.
fn generate_tls_config(
hostname: &str,
) -> anyhow::Result<(ClientConfig<'_>, Arc<rustls::ServerConfig>)> {
fn generate_tls_config<'a>(
hostname: &'a str,
common_name: &'a str,
) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
let (ca, cert, key) = generate_certs(hostname)?;
let server_config = {
let tls_config = {
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
@@ -291,7 +297,12 @@ mod tests {
ClientConfig { config, hostname }
};
Ok((client_config, server_config))
let tls_config = TlsConfig {
config: tls_config,
common_name: Some(common_name.to_string()),
};
Ok((client_config, tls_config))
}
#[async_trait]
@@ -346,7 +357,7 @@ mod tests {
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
let (mut stream, _creds) = handshake(client, tls.as_ref(), &cancel_map)
.await?
.context("handshake failed")?;
@@ -365,7 +376,8 @@ mod tests {
async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (_, server_config) = generate_tls_config("localhost")?;
let (_, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let client_err = tokio_postgres::Config::new()
@@ -393,7 +405,8 @@ mod tests {
async fn handshake_tls() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) = generate_tls_config("localhost")?;
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let (_client, _conn) = tokio_postgres::Config::new()
@@ -415,6 +428,7 @@ mod tests {
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.options("project=generic-project-name")
.ssl_mode(SslMode::Prefer)
.connect_raw(server, NoTls)
.await?;
@@ -476,7 +490,8 @@ mod tests {
async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) = generate_tls_config("localhost")?;
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
@@ -498,7 +513,8 @@ mod tests {
async fn scram_auth_mock() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) = generate_tls_config("localhost")?;
let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),

View File

@@ -5,6 +5,11 @@ use anyhow::Context;
use anyhow::Error;
use anyhow::Result;
use etcd_broker::subscription_value::SkTimelineInfo;
use etcd_broker::LeaseKeepAliveStream;
use etcd_broker::LeaseKeeper;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::time::Duration;
use tokio::spawn;
use tokio::task::JoinHandle;
@@ -21,7 +26,7 @@ use utils::zid::{NodeId, ZTenantTimelineId};
const RETRY_INTERVAL_MSEC: u64 = 1000;
const PUSH_INTERVAL_MSEC: u64 = 1000;
const LEASE_TTL_SEC: i64 = 5;
const LEASE_TTL_SEC: i64 = 10;
pub fn thread_main(conf: SafeKeeperConf) {
let runtime = runtime::Builder::new_current_thread()
@@ -154,13 +159,48 @@ pub fn get_candiate_name(system_id: NodeId) -> String {
format!("id_{system_id}")
}
async fn push_sk_info(
zttid: ZTenantTimelineId,
mut client: Client,
key: String,
sk_info: SkTimelineInfo,
mut lease: Lease,
) -> anyhow::Result<(ZTenantTimelineId, Lease)> {
let put_opts = PutOptions::new().with_lease(lease.id);
client
.put(
key.clone(),
serde_json::to_string(&sk_info)?,
Some(put_opts),
)
.await
.with_context(|| format!("failed to push safekeeper info to {}", key))?;
// revive the lease
lease
.keeper
.keep_alive()
.await
.context("failed to send LeaseKeepAliveRequest")?;
lease
.ka_stream
.message()
.await
.context("failed to receive LeaseKeepAliveResponse")?;
Ok((zttid, lease))
}
struct Lease {
id: i64,
keeper: LeaseKeeper,
ka_stream: LeaseKeepAliveStream,
}
/// Push once in a while data about all active timelines to the broker.
async fn push_loop(conf: SafeKeeperConf) -> anyhow::Result<()> {
let mut client = Client::connect(&conf.broker_endpoints, None).await?;
// Get and maintain lease to automatically delete obsolete data
let lease = client.lease_grant(LEASE_TTL_SEC, None).await?;
let (mut keeper, mut ka_stream) = client.lease_keep_alive(lease.id()).await?;
let mut leases: HashMap<ZTenantTimelineId, Lease> = HashMap::new();
let push_interval = Duration::from_millis(PUSH_INTERVAL_MSEC);
loop {
@@ -168,33 +208,46 @@ async fn push_loop(conf: SafeKeeperConf) -> anyhow::Result<()> {
// is under plain mutex. That's ok, all this code is not performance
// sensitive and there is no risk of deadlock as we don't await while
// lock is held.
for zttid in GlobalTimelines::get_active_timelines() {
if let Some(tli) = GlobalTimelines::get_loaded(zttid) {
let sk_info = tli.get_public_info(&conf)?;
let put_opts = PutOptions::new().with_lease(lease.id());
client
.put(
timeline_safekeeper_path(
conf.broker_etcd_prefix.clone(),
zttid,
conf.my_id,
),
serde_json::to_string(&sk_info)?,
Some(put_opts),
)
.await
.context("failed to push safekeeper info")?;
let active_tlis = GlobalTimelines::get_active_timelines();
// // Get and maintain (if not yet) per timeline lease to automatically delete obsolete data.
for zttid in active_tlis.iter() {
if let Entry::Vacant(v) = leases.entry(*zttid) {
let lease = client.lease_grant(LEASE_TTL_SEC, None).await?;
let (keeper, ka_stream) = client.lease_keep_alive(lease.id()).await?;
v.insert(Lease {
id: lease.id(),
keeper,
ka_stream,
});
}
}
// revive the lease
keeper
.keep_alive()
.await
.context("failed to send LeaseKeepAliveRequest")?;
ka_stream
.message()
.await
.context("failed to receive LeaseKeepAliveResponse")?;
leases.retain(|zttid, _| active_tlis.contains(zttid));
// Push data concurrently to not suffer from latency, with many timelines it can be slow.
let handles = active_tlis
.iter()
.filter_map(|zttid| GlobalTimelines::get_loaded(*zttid))
.map(|tli| {
let sk_info = tli.get_public_info(&conf);
let key = timeline_safekeeper_path(
conf.broker_etcd_prefix.clone(),
tli.zttid,
conf.my_id,
);
let lease = leases.remove(&tli.zttid).unwrap();
tokio::spawn(push_sk_info(tli.zttid, client.clone(), key, sk_info, lease))
})
.collect::<Vec<_>>();
for h in handles {
let (zttid, lease) = h.await??;
// It is ugly to pull leases from hash and then put it back, but
// otherwise we have to resort to long living per tli tasks (which
// would generate a lot of errors when etcd is down) as task wants to
// have 'static objects, we can't borrow to it.
leases.insert(zttid, lease);
}
sleep(push_interval).await;
}
}
@@ -221,15 +274,12 @@ async fn pull_loop(conf: SafeKeeperConf) -> Result<()> {
.await
.context("failed to subscribe for safekeeper info")?;
loop {
match subscription.fetch_data().await {
match subscription.value_updates.recv().await {
Some(new_info) => {
for (zttid, sk_info) in new_info {
// note: there are blocking operations below, but it's considered fine for now
if let Ok(tli) = GlobalTimelines::get(&conf, zttid, false) {
for (safekeeper_id, info) in sk_info {
tli.record_safekeeper_info(&info, safekeeper_id).await?
}
}
// note: there are blocking operations below, but it's considered fine for now
if let Ok(tli) = GlobalTimelines::get(&conf, new_info.key.id, false) {
tli.record_safekeeper_info(&new_info.value, new_info.key.node_id)
.await?
}
}
None => {

View File

@@ -239,6 +239,19 @@ pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result<SafeKeeperState>
remote_consistent_lsn: Lsn(0),
peers: Peers(vec![]),
});
} else if version == 5 {
info!("reading safekeeper control file version {}", version);
let mut oldstate = SafeKeeperState::des(&buf[..buf.len()])?;
if oldstate.timeline_start_lsn != Lsn(0) {
return Ok(oldstate);
}
// set special timeline_start_lsn because we don't know the real one
info!("setting timeline_start_lsn and local_start_lsn to Lsn(1)");
oldstate.timeline_start_lsn = Lsn(1);
oldstate.local_start_lsn = Lsn(1);
return Ok(oldstate);
}
bail!("unsupported safekeeper control file version {}", version)
}

View File

@@ -124,7 +124,7 @@ fn send_proposer_elected(spg: &mut SafekeeperPostgresHandler, term: Term, lsn: L
term,
start_streaming_at: lsn,
term_history: history,
timeline_start_lsn: Lsn(0),
timeline_start_lsn: lsn,
});
spg.timeline.get().process_msg(&proposer_elected_request)?;

View File

@@ -242,9 +242,9 @@ impl Collector for TimelineCollector {
let timeline_id = tli.zttid.timeline_id.to_string();
let labels = &[tenant_id.as_str(), timeline_id.as_str()];
let mut most_advanced: Option<utils::pq_proto::ZenithFeedback> = None;
let mut most_advanced: Option<utils::pq_proto::ReplicationFeedback> = None;
for replica in tli.replicas.iter() {
if let Some(replica_feedback) = replica.zenith_feedback {
if let Some(replica_feedback) = replica.pageserver_feedback {
if let Some(current) = most_advanced {
if current.ps_writelsn < replica_feedback.ps_writelsn {
most_advanced = Some(replica_feedback);

View File

@@ -23,12 +23,12 @@ use postgres_ffi::xlog_utils::MAX_SEND_SIZE;
use utils::{
bin_ser::LeSer,
lsn::Lsn,
pq_proto::{SystemId, ZenithFeedback},
pq_proto::{ReplicationFeedback, SystemId},
zid::{NodeId, ZTenantId, ZTenantTimelineId, ZTimelineId},
};
pub const SK_MAGIC: u32 = 0xcafeceefu32;
pub const SK_FORMAT_VERSION: u32 = 5;
pub const SK_FORMAT_VERSION: u32 = 6;
const SK_PROTOCOL_VERSION: u32 = 2;
const UNKNOWN_SERVER_VERSION: u32 = 0;
@@ -348,7 +348,7 @@ pub struct AppendResponse {
// a criterion for walproposer --sync mode exit
pub commit_lsn: Lsn,
pub hs_feedback: HotStandbyFeedback,
pub zenith_feedback: ZenithFeedback,
pub pageserver_feedback: ReplicationFeedback,
}
impl AppendResponse {
@@ -358,7 +358,7 @@ impl AppendResponse {
flush_lsn: Lsn(0),
commit_lsn: Lsn(0),
hs_feedback: HotStandbyFeedback::empty(),
zenith_feedback: ZenithFeedback::empty(),
pageserver_feedback: ReplicationFeedback::empty(),
}
}
}
@@ -476,7 +476,7 @@ impl AcceptorProposerMessage {
buf.put_u64_le(msg.hs_feedback.xmin);
buf.put_u64_le(msg.hs_feedback.catalog_xmin);
msg.zenith_feedback.serialize(buf)?
msg.pageserver_feedback.serialize(buf)?
}
}
@@ -677,7 +677,7 @@ where
commit_lsn: self.state.commit_lsn,
// will be filled by the upper code to avoid bothering safekeeper
hs_feedback: HotStandbyFeedback::empty(),
zenith_feedback: ZenithFeedback::empty(),
pageserver_feedback: ReplicationFeedback::empty(),
};
trace!("formed AppendResponse {:?}", ar);
ar

View File

@@ -13,15 +13,17 @@ use serde::{Deserialize, Serialize};
use std::cmp::min;
use std::net::Shutdown;
use std::sync::Arc;
use std::thread::sleep;
use std::time::Duration;
use std::{str, thread};
use tokio::sync::watch::Receiver;
use tokio::time::timeout;
use tracing::*;
use utils::{
bin_ser::BeSer,
lsn::Lsn,
postgres_backend::PostgresBackend,
pq_proto::{BeMessage, FeMessage, WalSndKeepAlive, XLogDataBody, ZenithFeedback},
pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody},
sock_split::ReadStream,
};
@@ -29,7 +31,7 @@ use utils::{
const HOT_STANDBY_FEEDBACK_TAG_BYTE: u8 = b'h';
const STANDBY_STATUS_UPDATE_TAG_BYTE: u8 = b'r';
// zenith extension of replication protocol
const ZENITH_STATUS_UPDATE_TAG_BYTE: u8 = b'z';
const NEON_STATUS_UPDATE_TAG_BYTE: u8 = b'z';
type FullTransactionId = u64;
@@ -122,15 +124,15 @@ impl ReplicationConn {
warn!("unexpected StandbyReply. Read-only postgres replicas are not supported in safekeepers yet.");
// timeline.update_replica_state(replica_id, Some(state));
}
Some(ZENITH_STATUS_UPDATE_TAG_BYTE) => {
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
let buf = Bytes::copy_from_slice(&m[9..]);
let reply = ZenithFeedback::parse(buf);
let reply = ReplicationFeedback::parse(buf);
trace!("ZenithFeedback is {:?}", reply);
// Only pageserver sends ZenithFeedback, so set the flag.
trace!("ReplicationFeedback is {:?}", reply);
// Only pageserver sends ReplicationFeedback, so set the flag.
// This replica is the source of information to resend to compute.
state.zenith_feedback = Some(reply);
state.pageserver_feedback = Some(reply);
timeline.update_replica_state(replica_id, state);
}
@@ -191,100 +193,142 @@ impl ReplicationConn {
}
})?;
let mut wal_seg_size: usize;
loop {
wal_seg_size = spg.timeline.get().get_state().1.server.wal_seg_size as usize;
if wal_seg_size == 0 {
error!("Cannot start replication before connecting to wal_proposer");
sleep(Duration::from_secs(1));
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
runtime.block_on(async move {
let (_, persisted_state) = spg.timeline.get().get_state();
// add persisted_state.timeline_start_lsn == Lsn(0) check
if persisted_state.server.wal_seg_size == 0 {
bail!("Cannot start replication before connecting to walproposer");
}
let wal_end = spg.timeline.get().get_end_of_wal();
// Walproposer gets special handling: safekeeper must give proposer all
// local WAL till the end, whether committed or not (walproposer will
// hang otherwise). That's because walproposer runs the consensus and
// synchronizes safekeepers on the most advanced one.
//
// There is a small risk of this WAL getting concurrently garbaged if
// another compute rises which collects majority and starts fixing log
// on this safekeeper itself. That's ok as (old) proposer will never be
// able to commit such WAL.
let stop_pos: Option<Lsn> = if spg.appname == Some("wal_proposer_recovery".to_string())
{
Some(wal_end)
} else {
None
};
info!("Start replication from {:?} till {:?}", start_pos, stop_pos);
// switch to copy
pgb.write_message(&BeMessage::CopyBothResponse)?;
let mut end_pos = Lsn(0);
let mut wal_reader = WalReader::new(
spg.conf.timeline_dir(&spg.timeline.get().zttid),
&persisted_state,
start_pos,
spg.conf.wal_backup_enabled,
)?;
// buffer for wal sending, limited by MAX_SEND_SIZE
let mut send_buf = vec![0u8; MAX_SEND_SIZE];
// watcher for commit_lsn updates
let mut commit_lsn_watch_rx = spg.timeline.get().get_commit_lsn_watch_rx();
loop {
if let Some(stop_pos) = stop_pos {
if start_pos >= stop_pos {
break; /* recovery finished */
}
end_pos = stop_pos;
} else {
/* Wait until we have some data to stream */
let lsn = wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await?;
if let Some(lsn) = lsn {
end_pos = lsn;
} else {
// TODO: also check once in a while whether we are walsender
// to right pageserver.
if spg.timeline.get().stop_walsender(replica_id)? {
// Shut down, timeline is suspended.
// TODO create proper error type for this
bail!("end streaming to {:?}", spg.appname);
}
// timeout expired: request pageserver status
pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
sent_ptr: end_pos.0,
timestamp: get_current_timestamp(),
request_reply: true,
}))
.context("Failed to send KeepAlive message")?;
continue;
}
}
let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize;
let send_size = min(send_size, send_buf.len());
let send_buf = &mut send_buf[..send_size];
// read wal into buffer
let send_size = wal_reader.read(send_buf).await?;
let send_buf = &send_buf[..send_size];
// Write some data to the network socket.
pgb.write_message(&BeMessage::XLogData(XLogDataBody {
wal_start: start_pos.0,
wal_end: end_pos.0,
timestamp: get_current_timestamp(),
data: send_buf,
}))
.context("Failed to send XLogData")?;
start_pos += send_size as u64;
trace!("sent WAL up to {}", start_pos);
}
Ok(())
})
}
}
const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1);
// Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn.
async fn wait_for_lsn(rx: &mut Receiver<Lsn>, lsn: Lsn) -> Result<Option<Lsn>> {
let commit_lsn: Lsn = *rx.borrow();
if commit_lsn > lsn {
return Ok(Some(commit_lsn));
}
let res = timeout(POLL_STATE_TIMEOUT, async move {
let mut commit_lsn;
loop {
rx.changed().await?;
commit_lsn = *rx.borrow();
if commit_lsn > lsn {
break;
}
}
let wal_end = spg.timeline.get().get_end_of_wal();
// Walproposer gets special handling: safekeeper must give proposer all
// local WAL till the end, whether committed or not (walproposer will
// hang otherwise). That's because walproposer runs the consensus and
// synchronizes safekeepers on the most advanced one.
//
// There is a small risk of this WAL getting concurrently garbaged if
// another compute rises which collects majority and starts fixing log
// on this safekeeper itself. That's ok as (old) proposer will never be
// able to commit such WAL.
let stop_pos: Option<Lsn> = if spg.appname == Some("wal_proposer_recovery".to_string()) {
Some(wal_end)
} else {
None
};
info!("Start replication from {:?} till {:?}", start_pos, stop_pos);
// switch to copy
pgb.write_message(&BeMessage::CopyBothResponse)?;
Ok(commit_lsn)
})
.await;
let mut end_pos = Lsn(0);
let mut wal_reader = WalReader::new(
spg.conf.timeline_dir(&spg.timeline.get().zttid),
wal_seg_size,
start_pos,
);
// buffer for wal sending, limited by MAX_SEND_SIZE
let mut send_buf = vec![0u8; MAX_SEND_SIZE];
loop {
if let Some(stop_pos) = stop_pos {
if start_pos >= stop_pos {
break; /* recovery finished */
}
end_pos = stop_pos;
} else {
/* Wait until we have some data to stream */
let lsn = spg.timeline.get().wait_for_lsn(start_pos);
if let Some(lsn) = lsn {
end_pos = lsn;
} else {
// TODO: also check once in a while whether we are walsender
// to right pageserver.
if spg.timeline.get().stop_walsender(replica_id)? {
// Shut down, timeline is suspended.
// TODO create proper error type for this
bail!("end streaming to {:?}", spg.appname);
}
// timeout expired: request pageserver status
pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
sent_ptr: end_pos.0,
timestamp: get_current_timestamp(),
request_reply: true,
}))
.context("Failed to send KeepAlive message")?;
continue;
}
}
let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize;
let send_size = min(send_size, send_buf.len());
let send_buf = &mut send_buf[..send_size];
// read wal into buffer
let send_size = wal_reader.read(send_buf)?;
let send_buf = &send_buf[..send_size];
// Write some data to the network socket.
pgb.write_message(&BeMessage::XLogData(XLogDataBody {
wal_start: start_pos.0,
wal_end: end_pos.0,
timestamp: get_current_timestamp(),
data: send_buf,
}))
.context("Failed to send XLogData")?;
start_pos += send_size as u64;
trace!("sent WAL up to {}", start_pos);
}
Ok(())
match res {
// success
Ok(Ok(commit_lsn)) => Ok(Some(commit_lsn)),
// error inside closure
Ok(Err(err)) => Err(err),
// timeout
Err(_) => Ok(None),
}
}

View File

@@ -11,17 +11,17 @@ use serde::Serialize;
use tokio::sync::watch;
use std::cmp::{max, min};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::fs::{self};
use std::sync::{Arc, Condvar, Mutex, MutexGuard};
use std::time::Duration;
use std::sync::{Arc, Mutex, MutexGuard};
use tokio::sync::mpsc::Sender;
use tracing::*;
use utils::{
lsn::Lsn,
pq_proto::ZenithFeedback,
pq_proto::ReplicationFeedback,
zid::{NodeId, ZTenantId, ZTenantTimelineId},
};
@@ -37,8 +37,6 @@ use crate::wal_storage;
use crate::wal_storage::Storage as wal_storage_iface;
use crate::SafeKeeperConf;
const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1);
/// Replica status update + hot standby feedback
#[derive(Debug, Clone, Copy)]
pub struct ReplicaState {
@@ -48,8 +46,8 @@ pub struct ReplicaState {
pub remote_consistent_lsn: Lsn,
/// combined hot standby feedback from all replicas
pub hs_feedback: HotStandbyFeedback,
/// Zenith specific feedback received from pageserver, if any
pub zenith_feedback: Option<ZenithFeedback>,
/// Replication specific feedback received from pageserver, if any
pub pageserver_feedback: Option<ReplicationFeedback>,
}
impl Default for ReplicaState {
@@ -68,7 +66,7 @@ impl ReplicaState {
xmin: u64::MAX,
catalog_xmin: u64::MAX,
},
zenith_feedback: None,
pageserver_feedback: None,
}
}
}
@@ -77,9 +75,6 @@ impl ReplicaState {
struct SharedState {
/// Safekeeper object
sk: SafeKeeper<control_file::FileStorage, wal_storage::PhysicalStorage>,
/// For receiving-sending wal cooperation
/// quorum commit LSN we've notified walsenders about
notified_commit_lsn: Lsn,
/// State of replicas
replicas: Vec<Option<ReplicaState>>,
/// True when WAL backup launcher oversees the timeline, making sure WAL is
@@ -112,7 +107,6 @@ impl SharedState {
let sk = SafeKeeper::new(zttid.timeline_id, control_store, wal_store, conf.my_id)?;
Ok(Self {
notified_commit_lsn: Lsn(0),
sk,
replicas: Vec::new(),
wal_backup_active: false,
@@ -131,7 +125,6 @@ impl SharedState {
info!("timeline {} restored", zttid.timeline_id);
Ok(Self {
notified_commit_lsn: Lsn(0),
sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store, conf.my_id)?,
replicas: Vec::new(),
wal_backup_active: false,
@@ -221,25 +214,25 @@ impl SharedState {
// we need to know which pageserver compute node considers to be main.
// See https://github.com/zenithdb/zenith/issues/1171
//
if let Some(zenith_feedback) = state.zenith_feedback {
if let Some(acc_feedback) = acc.zenith_feedback {
if acc_feedback.ps_writelsn < zenith_feedback.ps_writelsn {
if let Some(pageserver_feedback) = state.pageserver_feedback {
if let Some(acc_feedback) = acc.pageserver_feedback {
if acc_feedback.ps_writelsn < pageserver_feedback.ps_writelsn {
warn!("More than one pageserver is streaming WAL for the timeline. Feedback resolving is not fully supported yet.");
acc.zenith_feedback = Some(zenith_feedback);
acc.pageserver_feedback = Some(pageserver_feedback);
}
} else {
acc.zenith_feedback = Some(zenith_feedback);
acc.pageserver_feedback = Some(pageserver_feedback);
}
// last lsn received by pageserver
// FIXME if multiple pageservers are streaming WAL, last_received_lsn must be tracked per pageserver.
// See https://github.com/zenithdb/zenith/issues/1171
acc.last_received_lsn = Lsn::from(zenith_feedback.ps_writelsn);
acc.last_received_lsn = Lsn::from(pageserver_feedback.ps_writelsn);
// When at least one pageserver has preserved data up to remote_consistent_lsn,
// safekeeper is free to delete it, so choose max of all pageservers.
acc.remote_consistent_lsn = max(
Lsn::from(zenith_feedback.ps_applylsn),
Lsn::from(pageserver_feedback.ps_applylsn),
acc.remote_consistent_lsn,
);
}
@@ -271,8 +264,6 @@ pub struct Timeline {
/// For breeding receivers.
commit_lsn_watch_rx: watch::Receiver<Lsn>,
mutex: Mutex<SharedState>,
/// conditional variable used to notify wal senders
cond: Condvar,
}
impl Timeline {
@@ -289,7 +280,6 @@ impl Timeline {
commit_lsn_watch_tx,
commit_lsn_watch_rx,
mutex: Mutex::new(shared_state),
cond: Condvar::new(),
}
}
@@ -333,7 +323,7 @@ impl Timeline {
let mut shared_state = self.mutex.lock().unwrap();
if shared_state.num_computes == 0 {
let replica_state = shared_state.replicas[replica_id].unwrap();
let stop = shared_state.notified_commit_lsn == Lsn(0) || // no data at all yet
let stop = shared_state.sk.inmem.commit_lsn == Lsn(0) || // no data at all yet
(replica_state.remote_consistent_lsn != Lsn::MAX && // Lsn::MAX means that we don't know the latest LSN yet.
replica_state.remote_consistent_lsn >= shared_state.sk.inmem.commit_lsn);
if stop {
@@ -405,39 +395,6 @@ impl Timeline {
})
}
/// Timed wait for an LSN to be committed.
///
/// Returns the last committed LSN, which will be at least
/// as high as the LSN waited for, or None if timeout expired.
///
pub fn wait_for_lsn(&self, lsn: Lsn) -> Option<Lsn> {
let mut shared_state = self.mutex.lock().unwrap();
loop {
let commit_lsn = shared_state.notified_commit_lsn;
// This must be `>`, not `>=`.
if commit_lsn > lsn {
return Some(commit_lsn);
}
let result = self
.cond
.wait_timeout(shared_state, POLL_STATE_TIMEOUT)
.unwrap();
if result.1.timed_out() {
return None;
}
shared_state = result.0
}
}
// Notify caught-up WAL senders about new WAL data received
// TODO: replace-unify it with commit_lsn_watch.
fn notify_wal_senders(&self, shared_state: &mut MutexGuard<SharedState>) {
if shared_state.notified_commit_lsn < shared_state.sk.inmem.commit_lsn {
shared_state.notified_commit_lsn = shared_state.sk.inmem.commit_lsn;
self.cond.notify_all();
}
}
pub fn get_commit_lsn_watch_rx(&self) -> watch::Receiver<Lsn> {
self.commit_lsn_watch_rx.clone()
}
@@ -457,13 +414,11 @@ impl Timeline {
if let Some(AcceptorProposerMessage::AppendResponse(ref mut resp)) = rmsg {
let state = shared_state.get_replicas_state();
resp.hs_feedback = state.hs_feedback;
if let Some(zenith_feedback) = state.zenith_feedback {
resp.zenith_feedback = zenith_feedback;
if let Some(pageserver_feedback) = state.pageserver_feedback {
resp.pageserver_feedback = pageserver_feedback;
}
}
// Ping wal sender that new data might be available.
self.notify_wal_senders(&mut shared_state);
commit_lsn = shared_state.sk.inmem.commit_lsn;
}
self.commit_lsn_watch_tx.send(commit_lsn)?;
@@ -490,9 +445,9 @@ impl Timeline {
}
/// Prepare public safekeeper info for reporting.
pub fn get_public_info(&self, conf: &SafeKeeperConf) -> anyhow::Result<SkTimelineInfo> {
pub fn get_public_info(&self, conf: &SafeKeeperConf) -> SkTimelineInfo {
let shared_state = self.mutex.lock().unwrap();
Ok(SkTimelineInfo {
SkTimelineInfo {
last_log_term: Some(shared_state.sk.get_epoch()),
flush_lsn: Some(shared_state.sk.wal_store.flush_lsn()),
// note: this value is not flushed to control file yet and can be lost
@@ -505,7 +460,7 @@ impl Timeline {
peer_horizon_lsn: Some(shared_state.sk.inmem.peer_horizon_lsn),
safekeeper_connstr: Some(conf.listen_pg_addr.clone()),
backup_lsn: Some(shared_state.sk.inmem.backup_lsn),
})
}
}
/// Update timeline state with peer safekeeper data.
@@ -524,7 +479,6 @@ impl Timeline {
return Ok(());
}
shared_state.sk.record_safekeeper_info(sk_info)?;
self.notify_wal_senders(&mut shared_state);
is_wal_backup_action_pending = shared_state.update_status(self.zttid);
commit_lsn = shared_state.sk.inmem.commit_lsn;
}
@@ -671,6 +625,8 @@ impl GlobalTimelines {
zttid: ZTenantTimelineId,
create: bool,
) -> Result<Arc<Timeline>> {
let _enter = info_span!("", timeline = %zttid.tenant_id).entered();
let mut state = TIMELINES_STATE.lock().unwrap();
match state.timelines.get(&zttid) {
@@ -713,7 +669,7 @@ impl GlobalTimelines {
}
/// Get ZTenantTimelineIDs of all active timelines.
pub fn get_active_timelines() -> Vec<ZTenantTimelineId> {
pub fn get_active_timelines() -> HashSet<ZTenantTimelineId> {
let state = TIMELINES_STATE.lock().unwrap();
state
.timelines

View File

@@ -2,6 +2,7 @@ use anyhow::{Context, Result};
use etcd_broker::subscription_key::{
NodeKind, OperationKind, SkOperationKind, SubscriptionKey, SubscriptionKind,
};
use tokio::io::AsyncRead;
use tokio::task::JoinHandle;
use std::cmp::min;
@@ -10,7 +11,9 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use postgres_ffi::xlog_utils::{XLogFileName, XLogSegNo, XLogSegNoOffsetToRecPtr, PG_TLI};
use postgres_ffi::xlog_utils::{
XLogFileName, XLogSegNo, XLogSegNoOffsetToRecPtr, MAX_SEND_SIZE, PG_TLI,
};
use remote_storage::{GenericRemoteStorage, RemoteStorage};
use tokio::fs::File;
use tokio::runtime::Builder;
@@ -445,3 +448,49 @@ async fn backup_object(source_file: &Path, size: usize) -> Result<()> {
Ok(())
}
pub async fn read_object(
file_path: PathBuf,
offset: u64,
) -> (impl AsyncRead, JoinHandle<Result<()>>) {
let storage = REMOTE_STORAGE.get().expect("failed to get remote storage");
let (mut pipe_writer, pipe_reader) = tokio::io::duplex(MAX_SEND_SIZE);
let copy_result = tokio::spawn(async move {
let res = match storage.as_ref().unwrap() {
GenericRemoteStorage::Local(local_storage) => {
let source = local_storage.remote_object_id(&file_path)?;
info!(
"local download about to start from {} at offset {}",
source.display(),
offset
);
local_storage
.download_byte_range(&source, offset, None, &mut pipe_writer)
.await
}
GenericRemoteStorage::S3(s3_storage) => {
let s3key = s3_storage.remote_object_id(&file_path)?;
info!(
"S3 download about to start from {:?} at offset {}",
s3key, offset
);
s3_storage
.download_byte_range(&s3key, offset, None, &mut pipe_writer)
.await
}
};
if let Err(e) = res {
error!("failed to download WAL segment from remote storage: {}", e);
Err(e)
} else {
Ok(())
}
});
(pipe_reader, copy_result)
}

View File

@@ -8,7 +8,9 @@
//! Note that last file has `.partial` suffix, that's different from postgres.
use anyhow::{anyhow, bail, Context, Result};
use std::io::{Read, Seek, SeekFrom};
use std::io::{self, Seek, SeekFrom};
use std::pin::Pin;
use tokio::io::AsyncRead;
use lazy_static::lazy_static;
use postgres_ffi::xlog_utils::{
@@ -26,6 +28,7 @@ use utils::{lsn::Lsn, zid::ZTenantTimelineId};
use crate::safekeeper::SafeKeeperState;
use crate::wal_backup::read_object;
use crate::SafeKeeperConf;
use postgres_ffi::xlog_utils::{XLogFileName, XLOG_BLCKSZ};
@@ -33,6 +36,8 @@ use postgres_ffi::waldecoder::WalStreamDecoder;
use metrics::{register_histogram_vec, Histogram, HistogramVec, DISK_WRITE_SECONDS_BUCKETS};
use tokio::io::{AsyncReadExt, AsyncSeekExt};
lazy_static! {
// The prometheus crate does not support u64 yet, i64 only (see `IntGauge`).
// i64 is faster than f64, so update to u64 when available.
@@ -504,69 +509,123 @@ pub struct WalReader {
timeline_dir: PathBuf,
wal_seg_size: usize,
pos: Lsn,
file: Option<File>,
wal_segment: Option<Pin<Box<dyn AsyncRead>>>,
enable_remote_read: bool,
// S3 will be used to read WAL if LSN is not available locally
local_start_lsn: Lsn,
}
impl WalReader {
pub fn new(timeline_dir: PathBuf, wal_seg_size: usize, pos: Lsn) -> Self {
Self {
timeline_dir,
wal_seg_size,
pos,
file: None,
pub fn new(
timeline_dir: PathBuf,
state: &SafeKeeperState,
start_pos: Lsn,
enable_remote_read: bool,
) -> Result<Self> {
if start_pos < state.timeline_start_lsn {
bail!(
"Requested streaming from {}, which is before the start of the timeline {}",
start_pos,
state.timeline_start_lsn
);
}
// TODO: add state.timeline_start_lsn == Lsn(0) check
if state.server.wal_seg_size == 0 || state.local_start_lsn == Lsn(0) {
bail!("state uninitialized, no data to read");
}
Ok(Self {
timeline_dir,
wal_seg_size: state.server.wal_seg_size as usize,
pos: start_pos,
wal_segment: None,
enable_remote_read,
local_start_lsn: state.local_start_lsn,
})
}
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
// Take the `File` from `wal_file`, or open a new file.
let mut file = match self.file.take() {
Some(file) => file,
None => {
// Open a new file.
let segno = self.pos.segment_number(self.wal_seg_size);
let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size);
let wal_file_path = self.timeline_dir.join(wal_file_name);
Self::open_wal_file(&wal_file_path)?
}
pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
let mut wal_segment = match self.wal_segment.take() {
Some(reader) => reader,
None => self.open_segment().await?,
};
let xlogoff = self.pos.segment_offset(self.wal_seg_size) as usize;
// How much to read and send in message? We cannot cross the WAL file
// boundary, and we don't want send more than provided buffer.
let xlogoff = self.pos.segment_offset(self.wal_seg_size) as usize;
let send_size = min(buf.len(), self.wal_seg_size - xlogoff);
// Read some data from the file.
let buf = &mut buf[0..send_size];
file.seek(SeekFrom::Start(xlogoff as u64))
.and_then(|_| file.read_exact(buf))
.context("Failed to read data from WAL file")?;
let send_size = wal_segment.read_exact(buf).await?;
self.pos += send_size as u64;
// Decide whether to reuse this file. If we don't set wal_file here
// a new file will be opened next time.
// Decide whether to reuse this file. If we don't set wal_segment here
// a new reader will be opened next time.
if self.pos.segment_offset(self.wal_seg_size) != 0 {
self.file = Some(file);
self.wal_segment = Some(wal_segment);
}
Ok(send_size)
}
/// Open WAL segment at the current position of the reader.
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead>>> {
let xlogoff = self.pos.segment_offset(self.wal_seg_size) as usize;
let segno = self.pos.segment_number(self.wal_seg_size);
let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size);
let wal_file_path = self.timeline_dir.join(wal_file_name);
// Try to open local file, if we may have WAL locally
if self.pos >= self.local_start_lsn {
let res = Self::open_wal_file(&wal_file_path).await;
match res {
Ok(mut file) => {
file.seek(SeekFrom::Start(xlogoff as u64)).await?;
return Ok(Box::pin(file));
}
Err(e) => {
let is_not_found = e.chain().any(|e| {
if let Some(e) = e.downcast_ref::<io::Error>() {
e.kind() == io::ErrorKind::NotFound
} else {
false
}
});
if !is_not_found {
return Err(e);
}
// NotFound is expected, fall through to remote read
}
};
}
// Try to open remote file, if remote reads are enabled
if self.enable_remote_read {
let (reader, _) = read_object(wal_file_path, xlogoff as u64).await;
return Ok(Box::pin(reader));
}
bail!("WAL segment is not found")
}
/// Helper function for opening a wal file.
fn open_wal_file(wal_file_path: &Path) -> Result<File> {
async fn open_wal_file(wal_file_path: &Path) -> Result<tokio::fs::File> {
// First try to open the .partial file.
let mut partial_path = wal_file_path.to_owned();
partial_path.set_extension("partial");
if let Ok(opened_file) = File::open(&partial_path) {
if let Ok(opened_file) = tokio::fs::File::open(&partial_path).await {
return Ok(opened_file);
}
// If that failed, try it without the .partial extension.
File::open(&wal_file_path)
tokio::fs::File::open(&wal_file_path)
.await
.with_context(|| format!("Failed to open WAL file {:?}", wal_file_path))
.map_err(|e| {
error!("{}", e);
warn!("{}", e);
e
})
}

View File

@@ -45,7 +45,7 @@ If you want to run all tests that have the string "bench" in their names:
Useful environment variables:
`ZENITH_BIN`: The directory where zenith binaries can be found.
`NEON_BIN`: The directory where neon binaries can be found.
`POSTGRES_DISTRIB_DIR`: The directory where postgres distribution can be found.
`TEST_OUTPUT`: Set the directory where test state and test output files
should go.

View File

@@ -35,9 +35,14 @@ def test_createdb(neon_simple_env: NeonEnv):
with closing(db.connect(dbname='foodb')) as conn:
with conn.cursor() as cur:
# Check database size in both branches
cur.execute(
'select pg_size_pretty(pg_database_size(%s)), pg_size_pretty(sum(pg_relation_size(oid))) from pg_class where relisshared is false;',
('foodb', ))
cur.execute("""
select pg_size_pretty(pg_database_size('foodb')),
pg_size_pretty(
sum(pg_relation_size(oid, 'main'))
+sum(pg_relation_size(oid, 'vm'))
+sum(pg_relation_size(oid, 'fsm'))
) FROM pg_class where relisshared is false
""")
res = cur.fetchone()
# check that dbsize equals sum of all relation sizes, excluding shared ones
# This is how we define dbsize in neon for now

View File

@@ -0,0 +1,73 @@
import subprocess
from contextlib import closing
import psycopg2.extras
import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, PortDistributor, VanillaPostgres
from fixtures.neon_fixtures import pg_distrib_dir
import os
from fixtures.utils import mkdir_if_needed, subprocess_capture
import shutil
import getpass
import pwd
num_rows = 1000
# Ensure that regular postgres can start from fullbackup
def test_fullbackup(neon_env_builder: NeonEnvBuilder,
pg_bin: PgBin,
port_distributor: PortDistributor):
neon_env_builder.num_safekeepers = 1
env = neon_env_builder.init_start()
env.neon_cli.create_branch('test_fullbackup')
pgmain = env.postgres.create_start('test_fullbackup')
log.info("postgres is running on 'test_fullbackup' branch")
timeline = pgmain.safe_psql("SHOW neon.timeline_id")[0][0]
with closing(pgmain.connect()) as conn:
with conn.cursor() as cur:
# data loading may take a while, so increase statement timeout
cur.execute("SET statement_timeout='300s'")
cur.execute(f'''CREATE TABLE tbl AS SELECT 'long string to consume some space' || g
from generate_series(1,{num_rows}) g''')
cur.execute("CHECKPOINT")
cur.execute('SELECT pg_current_wal_insert_lsn()')
lsn = cur.fetchone()[0]
log.info(f"start_backup_lsn = {lsn}")
# Set LD_LIBRARY_PATH in the env properly, otherwise we may use the wrong libpq.
# PgBin sets it automatically, but here we need to pipe psql output to the tar command.
psql_env = {'LD_LIBRARY_PATH': os.path.join(str(pg_distrib_dir), 'lib')}
# Get and unpack fullbackup from pageserver
restored_dir_path = os.path.join(env.repo_dir, "restored_datadir")
os.mkdir(restored_dir_path, 0o750)
query = f"fullbackup {env.initial_tenant.hex} {timeline} {lsn}"
cmd = ["psql", "--no-psqlrc", env.pageserver.connstr(), "-c", query]
result_basepath = pg_bin.run_capture(cmd, env=psql_env)
tar_output_file = result_basepath + ".stdout"
subprocess_capture(str(env.repo_dir), ["tar", "-xf", tar_output_file, "-C", restored_dir_path])
# HACK
# fullbackup returns neon specific pg_control and first WAL segment
# use resetwal to overwrite it
pg_resetwal_path = os.path.join(pg_bin.pg_bin_path, 'pg_resetwal')
cmd = [pg_resetwal_path, "-D", restored_dir_path]
pg_bin.run_capture(cmd, env=psql_env)
# Restore from the backup and find the data we inserted
port = port_distributor.get_port()
with VanillaPostgres(restored_dir_path, pg_bin, port, init=False) as vanilla_pg:
# TODO make port an optional argument
vanilla_pg.configure([
f"port={port}",
])
vanilla_pg.start()
num_rows_found = vanilla_pg.safe_psql('select count(*) from tbl;', user="cloud_admin")[0][0]
assert num_rows == num_rows_found

View File

@@ -0,0 +1,193 @@
import pytest
from fixtures.neon_fixtures import NeonEnvBuilder, wait_for_upload, wait_for_last_record_lsn
from fixtures.utils import lsn_from_hex, lsn_to_hex
from uuid import UUID, uuid4
import tarfile
import os
import shutil
from pathlib import Path
import json
from fixtures.utils import subprocess_capture
from fixtures.log_helper import log
from contextlib import closing
from fixtures.neon_fixtures import pg_distrib_dir
@pytest.mark.timeout(600)
def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_builder):
# Put data in vanilla pg
vanilla_pg.start()
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
vanilla_pg.safe_psql('''create table t as select 'long string to consume some space' || g
from generate_series(1,300000) g''')
assert vanilla_pg.safe_psql('select count(*) from t') == [(300000, )]
# Take basebackup
basebackup_dir = os.path.join(test_output_dir, "basebackup")
base_tar = os.path.join(basebackup_dir, "base.tar")
wal_tar = os.path.join(basebackup_dir, "pg_wal.tar")
os.mkdir(basebackup_dir)
vanilla_pg.safe_psql("CHECKPOINT")
pg_bin.run([
"pg_basebackup",
"-F",
"tar",
"-d",
vanilla_pg.connstr(),
"-D",
basebackup_dir,
])
# Make corrupt base tar with missing pg_control
unpacked_base = os.path.join(basebackup_dir, "unpacked-base")
corrupt_base_tar = os.path.join(unpacked_base, "corrupt-base.tar")
os.mkdir(unpacked_base, 0o750)
subprocess_capture(str(test_output_dir), ["tar", "-xf", base_tar, "-C", unpacked_base])
os.remove(os.path.join(unpacked_base, "global/pg_control"))
subprocess_capture(str(test_output_dir),
["tar", "-cf", "corrupt-base.tar"] + os.listdir(unpacked_base),
cwd=unpacked_base)
# Get start_lsn and end_lsn
with open(os.path.join(basebackup_dir, "backup_manifest")) as f:
manifest = json.load(f)
start_lsn = manifest["WAL-Ranges"][0]["Start-LSN"]
end_lsn = manifest["WAL-Ranges"][0]["End-LSN"]
node_name = "import_from_vanilla"
tenant = uuid4()
timeline = uuid4()
# Set up pageserver for import
neon_env_builder.enable_local_fs_remote_storage()
env = neon_env_builder.init_start()
env.pageserver.http_client().tenant_create(tenant)
def import_tar(base, wal):
env.neon_cli.raw_cli([
"timeline",
"import",
"--tenant-id",
tenant.hex,
"--timeline-id",
timeline.hex,
"--node-name",
node_name,
"--base-lsn",
start_lsn,
"--base-tarfile",
base,
"--end-lsn",
end_lsn,
"--wal-tarfile",
wal,
])
# Importing corrupt backup fails
with pytest.raises(Exception):
import_tar(corrupt_base_tar, wal_tar)
# Clean up
# TODO it should clean itself
client = env.pageserver.http_client()
client.timeline_detach(tenant, timeline)
# Importing correct backup works
import_tar(base_tar, wal_tar)
# Wait for data to land in s3
wait_for_last_record_lsn(client, tenant, timeline, lsn_from_hex(end_lsn))
wait_for_upload(client, tenant, timeline, lsn_from_hex(end_lsn))
# Check it worked
pg = env.postgres.create_start(node_name, tenant_id=tenant)
assert pg.safe_psql('select count(*) from t') == [(300000, )]
@pytest.mark.timeout(600)
def test_import_from_pageserver(test_output_dir, pg_bin, vanilla_pg, neon_env_builder):
num_rows = 3000
neon_env_builder.num_safekeepers = 1
neon_env_builder.enable_local_fs_remote_storage()
env = neon_env_builder.init_start()
env.neon_cli.create_branch('test_import_from_pageserver')
pgmain = env.postgres.create_start('test_import_from_pageserver')
log.info("postgres is running on 'test_import_from_pageserver' branch")
timeline = pgmain.safe_psql("SHOW neon.timeline_id")[0][0]
with closing(pgmain.connect()) as conn:
with conn.cursor() as cur:
# data loading may take a while, so increase statement timeout
cur.execute("SET statement_timeout='300s'")
cur.execute(f'''CREATE TABLE tbl AS SELECT 'long string to consume some space' || g
from generate_series(1,{num_rows}) g''')
cur.execute("CHECKPOINT")
cur.execute('SELECT pg_current_wal_insert_lsn()')
lsn = cur.fetchone()[0]
log.info(f"start_backup_lsn = {lsn}")
# Set LD_LIBRARY_PATH in the env properly, otherwise we may use the wrong libpq.
# PgBin sets it automatically, but here we need to pipe psql output to the tar command.
psql_env = {'LD_LIBRARY_PATH': os.path.join(str(pg_distrib_dir), 'lib')}
# Get a fullbackup from pageserver
query = f"fullbackup { env.initial_tenant.hex} {timeline} {lsn}"
cmd = ["psql", "--no-psqlrc", env.pageserver.connstr(), "-c", query]
result_basepath = pg_bin.run_capture(cmd, env=psql_env)
tar_output_file = result_basepath + ".stdout"
# Stop the first pageserver instance, erase all its data
env.postgres.stop_all()
env.pageserver.stop()
dir_to_clear = Path(env.repo_dir) / 'tenants'
shutil.rmtree(dir_to_clear)
os.mkdir(dir_to_clear)
#start the pageserver again
env.pageserver.start()
# Import using another tenantid, because we use the same pageserver.
# TODO Create another pageserver to maeke test more realistic.
tenant = uuid4()
# Import to pageserver
node_name = "import_from_pageserver"
client = env.pageserver.http_client()
client.tenant_create(tenant)
env.neon_cli.raw_cli([
"timeline",
"import",
"--tenant-id",
tenant.hex,
"--timeline-id",
timeline,
"--node-name",
node_name,
"--base-lsn",
lsn,
"--base-tarfile",
os.path.join(tar_output_file),
])
# Wait for data to land in s3
wait_for_last_record_lsn(client, tenant, UUID(timeline), lsn_from_hex(lsn))
wait_for_upload(client, tenant, UUID(timeline), lsn_from_hex(lsn))
# Check it worked
pg = env.postgres.create_start(node_name, tenant_id=tenant)
assert pg.safe_psql('select count(*) from tbl') == [(num_rows, )]
# Take another fullbackup
query = f"fullbackup { tenant.hex} {timeline} {lsn}"
cmd = ["psql", "--no-psqlrc", env.pageserver.connstr(), "-c", query]
result_basepath = pg_bin.run_capture(cmd, env=psql_env)
new_tar_output_file = result_basepath + ".stdout"
# Check it's the same as the first fullbackup
# TODO pageserver should be checking checksum
assert os.path.getsize(tar_output_file) == os.path.getsize(new_tar_output_file)

View File

@@ -2,7 +2,7 @@ import pytest
def test_proxy_select_1(static_proxy):
static_proxy.safe_psql("select 1;")
static_proxy.safe_psql("select 1;", options="project=generic-project-name")
# Pass extra options to the server.

View File

@@ -1,5 +1,5 @@
# It's possible to run any regular test with the local fs remote storage via
# env ZENITH_PAGESERVER_OVERRIDES="remote_storage={local_path='/tmp/zenith_zzz/'}" poetry ......
# env ZENITH_PAGESERVER_OVERRIDES="remote_storage={local_path='/tmp/neon_zzz/'}" poetry ......
import shutil, os
from contextlib import closing

View File

@@ -8,7 +8,6 @@ import time
def test_timeline_size(neon_simple_env: NeonEnv):
env = neon_simple_env
# Branch at the point where only 100 rows were inserted
new_timeline_id = env.neon_cli.create_branch('test_timeline_size', 'empty')
client = env.pageserver.http_client()
@@ -23,7 +22,6 @@ def test_timeline_size(neon_simple_env: NeonEnv):
with conn.cursor() as cur:
cur.execute("SHOW neon.timeline_id")
# Create table, and insert the first 100 rows
cur.execute("CREATE TABLE foo (t text)")
cur.execute("""
INSERT INTO foo
@@ -43,6 +41,51 @@ def test_timeline_size(neon_simple_env: NeonEnv):
"current_logical_size_non_incremental"]
def test_timeline_size_createdropdb(neon_simple_env: NeonEnv):
env = neon_simple_env
new_timeline_id = env.neon_cli.create_branch('test_timeline_size', 'empty')
client = env.pageserver.http_client()
timeline_details = assert_local(client, env.initial_tenant, new_timeline_id)
assert timeline_details['local']['current_logical_size'] == timeline_details['local'][
'current_logical_size_non_incremental']
pgmain = env.postgres.create_start("test_timeline_size")
log.info("postgres is running on 'test_timeline_size' branch")
with closing(pgmain.connect()) as conn:
with conn.cursor() as cur:
cur.execute("SHOW neon.timeline_id")
res = assert_local(client, env.initial_tenant, new_timeline_id)
local_details = res['local']
assert local_details["current_logical_size"] == local_details[
"current_logical_size_non_incremental"]
cur.execute('CREATE DATABASE foodb')
with closing(pgmain.connect(dbname='foodb')) as conn:
with conn.cursor() as cur2:
cur2.execute("CREATE TABLE foo (t text)")
cur2.execute("""
INSERT INTO foo
SELECT 'long string to consume some space' || g
FROM generate_series(1, 10) g
""")
res = assert_local(client, env.initial_tenant, new_timeline_id)
local_details = res['local']
assert local_details["current_logical_size"] == local_details[
"current_logical_size_non_incremental"]
cur.execute('DROP DATABASE foodb')
res = assert_local(client, env.initial_tenant, new_timeline_id)
local_details = res['local']
assert local_details["current_logical_size"] == local_details[
"current_logical_size_non_incremental"]
# wait until received_lsn_lag is 0
def wait_for_pageserver_catchup(pgmain: Postgres, polling_interval=1, timeout=60):
started_at = time.time()

View File

@@ -2,6 +2,7 @@ import pytest
import random
import time
import os
import shutil
import signal
import subprocess
import sys
@@ -353,7 +354,7 @@ def test_broker(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize('auth_enabled', [False, True])
def test_wal_removal(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
neon_env_builder.num_safekeepers = 2
# to advance remote_consistent_llsn
# to advance remote_consistent_lsn
neon_env_builder.enable_local_fs_remote_storage()
neon_env_builder.auth_enabled = auth_enabled
env = neon_env_builder.init_start()
@@ -437,6 +438,26 @@ def wait_segment_offload(tenant_id, timeline_id, live_sk, seg_end):
time.sleep(0.5)
def wait_wal_trim(tenant_id, timeline_id, sk, target_size):
started_at = time.time()
http_cli = sk.http_client()
while True:
tli_status = http_cli.timeline_status(tenant_id, timeline_id)
sk_wal_size = get_dir_size(os.path.join(sk.data_dir(), tenant_id,
timeline_id)) / 1024 / 1024
log.info(f"Safekeeper id={sk.id} wal_size={sk_wal_size:.2f}MB status={tli_status}")
if sk_wal_size <= target_size:
break
elapsed = time.time() - started_at
if elapsed > 20:
raise RuntimeError(
f"timed out waiting {elapsed:.0f}s for sk_id={sk.id} to trim WAL to {target_size:.2f}MB, current size is {sk_wal_size:.2f}MB"
)
time.sleep(0.5)
@pytest.mark.parametrize('storage_type', ['mock_s3', 'local_fs'])
def test_wal_backup(neon_env_builder: NeonEnvBuilder, storage_type: str):
neon_env_builder.num_safekeepers = 3
@@ -485,6 +506,116 @@ def test_wal_backup(neon_env_builder: NeonEnvBuilder, storage_type: str):
wait_segment_offload(tenant_id, timeline_id, env.safekeepers[1], '0/5000000')
@pytest.mark.parametrize('storage_type', ['mock_s3', 'local_fs'])
def test_s3_wal_replay(neon_env_builder: NeonEnvBuilder, storage_type: str):
neon_env_builder.num_safekeepers = 3
if storage_type == 'local_fs':
neon_env_builder.enable_local_fs_remote_storage()
elif storage_type == 'mock_s3':
neon_env_builder.enable_s3_mock_remote_storage('test_s3_wal_replay')
else:
raise RuntimeError(f'Unknown storage type: {storage_type}')
neon_env_builder.remote_storage_users = RemoteStorageUsers.SAFEKEEPER
env = neon_env_builder.init_start()
env.neon_cli.create_branch('test_s3_wal_replay')
env.pageserver.stop()
pageserver_tenants_dir = os.path.join(env.repo_dir, 'tenants')
pageserver_fresh_copy = os.path.join(env.repo_dir, 'tenants_fresh')
log.info(f"Creating a copy of pageserver in a fresh state at {pageserver_fresh_copy}")
shutil.copytree(pageserver_tenants_dir, pageserver_fresh_copy)
env.pageserver.start()
pg = env.postgres.create_start('test_s3_wal_replay')
# learn neon timeline from compute
tenant_id = pg.safe_psql("show neon.tenant_id")[0][0]
timeline_id = pg.safe_psql("show neon.timeline_id")[0][0]
expected_sum = 0
with closing(pg.connect()) as conn:
with conn.cursor() as cur:
cur.execute("create table t(key int, value text)")
cur.execute("insert into t values (1, 'payload')")
expected_sum += 1
offloaded_seg_end = ['0/3000000']
for seg_end in offloaded_seg_end:
# roughly fills two segments
cur.execute("insert into t select generate_series(1,500000), 'payload'")
expected_sum += 500000 * 500001 // 2
cur.execute("select sum(key) from t")
assert cur.fetchone()[0] == expected_sum
for sk in env.safekeepers:
wait_segment_offload(tenant_id, timeline_id, sk, seg_end)
# advance remote_consistent_lsn to trigger WAL trimming
# this LSN should be less than commit_lsn, so timeline will be active=true in safekeepers, to push etcd updates
env.safekeepers[0].http_client().record_safekeeper_info(
tenant_id, timeline_id, {'remote_consistent_lsn': offloaded_seg_end[-1]})
for sk in env.safekeepers:
# require WAL to be trimmed, so no more than one segment is left on disk
wait_wal_trim(tenant_id, timeline_id, sk, 16 * 1.5)
cur.execute('SELECT pg_current_wal_flush_lsn()')
last_lsn = cur.fetchone()[0]
pageserver_lsn = env.pageserver.http_client().timeline_detail(
uuid.UUID(tenant_id), uuid.UUID((timeline_id)))["local"]["last_record_lsn"]
lag = lsn_from_hex(last_lsn) - lsn_from_hex(pageserver_lsn)
log.info(
f'Pageserver last_record_lsn={pageserver_lsn}; flush_lsn={last_lsn}; lag before replay is {lag / 1024}kb'
)
# replace pageserver with a fresh copy
pg.stop_and_destroy()
env.pageserver.stop()
log.info(f'Removing current pageserver state at {pageserver_tenants_dir}')
shutil.rmtree(pageserver_tenants_dir)
log.info(f'Copying fresh pageserver state from {pageserver_fresh_copy}')
shutil.move(pageserver_fresh_copy, pageserver_tenants_dir)
# start pageserver and wait for replay
env.pageserver.start()
wait_lsn_timeout = 60 * 3
started_at = time.time()
last_debug_print = 0.0
while True:
elapsed = time.time() - started_at
if elapsed > wait_lsn_timeout:
raise RuntimeError(f'Timed out waiting for WAL redo')
pageserver_lsn = env.pageserver.http_client().timeline_detail(
uuid.UUID(tenant_id), uuid.UUID((timeline_id)))["local"]["last_record_lsn"]
lag = lsn_from_hex(last_lsn) - lsn_from_hex(pageserver_lsn)
if time.time() > last_debug_print + 10 or lag <= 0:
last_debug_print = time.time()
log.info(f'Pageserver last_record_lsn={pageserver_lsn}; lag is {lag / 1024}kb')
if lag <= 0:
break
time.sleep(1)
log.info(f'WAL redo took {elapsed} s')
# verify data
pg.create_start('test_s3_wal_replay')
with closing(pg.connect()) as conn:
with conn.cursor() as cur:
cur.execute("select sum(key) from t")
assert cur.fetchone()[0] == expected_sum
class ProposerPostgres(PgProtocol):
"""Object for running postgres without NeonEnv"""
def __init__(self,

View File

@@ -29,7 +29,7 @@ from dataclasses import dataclass
# Type-related stuff
from psycopg2.extensions import connection as PgConnection
from psycopg2.extensions import make_dsn, parse_dsn
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast, Union, Tuple
from typing_extensions import Literal
import requests
@@ -50,7 +50,7 @@ A fixture is created with the decorator @pytest.fixture decorator.
See docs: https://docs.pytest.org/en/6.2.x/fixture.html
There are several environment variables that can control the running of tests:
ZENITH_BIN, POSTGRES_DISTRIB_DIR, etc. See README.md for more information.
NEON_BIN, POSTGRES_DISTRIB_DIR, etc. See README.md for more information.
There's no need to import this file to use it. It should be declared as a plugin
inside conftest.py, and that makes it available to all tests.
@@ -151,7 +151,7 @@ def pytest_configure(config):
return
# Find the neon binaries.
global neon_binpath
env_neon_bin = os.environ.get('ZENITH_BIN')
env_neon_bin = os.environ.get('NEON_BIN')
if env_neon_bin:
neon_binpath = env_neon_bin
else:
@@ -500,6 +500,8 @@ class NeonEnvBuilder:
num_safekeepers: int = 1,
# Use non-standard SK ids to check for various parsing bugs
safekeepers_id_start: int = 0,
# fsync is disabled by default to make the tests go faster
safekeepers_enable_fsync: bool = False,
auth_enabled: bool = False,
rust_log_override: Optional[str] = None,
default_branch_name=DEFAULT_BRANCH_NAME):
@@ -513,6 +515,7 @@ class NeonEnvBuilder:
self.pageserver_config_override = pageserver_config_override
self.num_safekeepers = num_safekeepers
self.safekeepers_id_start = safekeepers_id_start
self.safekeepers_enable_fsync = safekeepers_enable_fsync
self.auth_enabled = auth_enabled
self.default_branch_name = default_branch_name
self.env: Optional[NeonEnv] = None
@@ -666,7 +669,7 @@ class NeonEnv:
id = {id}
pg_port = {port.pg}
http_port = {port.http}
sync = false # Disable fsyncs to make the tests go faster""")
sync = {'true' if config.safekeepers_enable_fsync else 'false'}""")
if config.auth_enabled:
toml += textwrap.dedent(f"""
auth_enabled = true
@@ -1373,12 +1376,14 @@ def pg_bin(test_output_dir: str) -> PgBin:
class VanillaPostgres(PgProtocol):
def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int):
def __init__(self, pgdatadir: str, pg_bin: PgBin, port: int, init=True):
super().__init__(host='localhost', port=port, dbname='postgres')
self.pgdatadir = pgdatadir
self.pg_bin = pg_bin
self.running = False
self.pg_bin.run_capture(['initdb', '-D', pgdatadir])
if init:
self.pg_bin.run_capture(['initdb', '-D', pgdatadir])
self.configure([f"port = {port}\n"])
def configure(self, options: List[str]):
"""Append lines into postgresql.conf file."""
@@ -1393,12 +1398,12 @@ class VanillaPostgres(PgProtocol):
if log_path is None:
log_path = os.path.join(self.pgdatadir, "pg.log")
self.pg_bin.run_capture(['pg_ctl', '-D', self.pgdatadir, '-l', log_path, 'start'])
self.pg_bin.run_capture(['pg_ctl', '-w', '-D', self.pgdatadir, '-l', log_path, 'start'])
def stop(self):
assert self.running
self.running = False
self.pg_bin.run_capture(['pg_ctl', '-D', self.pgdatadir, 'stop'])
self.pg_bin.run_capture(['pg_ctl', '-w', '-D', self.pgdatadir, 'stop'])
def get_subdir_size(self, subdir) -> int:
"""Return size of pgdatadir subdirectory in bytes."""
@@ -1413,10 +1418,12 @@ class VanillaPostgres(PgProtocol):
@pytest.fixture(scope='function')
def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]:
def vanilla_pg(test_output_dir: str,
port_distributor: PortDistributor) -> Iterator[VanillaPostgres]:
pgdatadir = os.path.join(test_output_dir, "pgdata-vanilla")
pg_bin = PgBin(test_output_dir)
with VanillaPostgres(pgdatadir, pg_bin, 5432) as vanilla_pg:
port = port_distributor.get_port()
with VanillaPostgres(pgdatadir, pg_bin, port) as vanilla_pg:
yield vanilla_pg
@@ -1462,7 +1469,7 @@ def remote_pg(test_output_dir: str) -> Iterator[RemotePostgres]:
class NeonProxy(PgProtocol):
def __init__(self, port: int):
def __init__(self, port: int, pg_port: int):
super().__init__(host="127.0.0.1",
user="proxy_user",
password="pytest2",
@@ -1471,9 +1478,10 @@ class NeonProxy(PgProtocol):
self.http_port = 7001
self.host = "127.0.0.1"
self.port = port
self.pg_port = pg_port
self._popen: Optional[subprocess.Popen[bytes]] = None
def start_static(self, addr="127.0.0.1:5432") -> None:
def start(self) -> None:
assert self._popen is None
# Start proxy
@@ -1482,7 +1490,8 @@ class NeonProxy(PgProtocol):
args.extend(["--http", f"{self.host}:{self.http_port}"])
args.extend(["--proxy", f"{self.host}:{self.port}"])
args.extend(["--auth-backend", "postgres"])
args.extend(["--auth-endpoint", "postgres://proxy_auth:pytest1@localhost:5432/postgres"])
args.extend(
["--auth-endpoint", f"postgres://proxy_auth:pytest1@localhost:{self.pg_port}/postgres"])
self._popen = subprocess.Popen(args)
self._wait_until_ready()
@@ -1501,14 +1510,16 @@ class NeonProxy(PgProtocol):
@pytest.fixture(scope='function')
def static_proxy(vanilla_pg) -> Iterator[NeonProxy]:
def static_proxy(vanilla_pg, port_distributor) -> Iterator[NeonProxy]:
"""Neon proxy that routes directly to vanilla postgres."""
vanilla_pg.start()
vanilla_pg.safe_psql("create user proxy_auth with password 'pytest1' superuser")
vanilla_pg.safe_psql("create user proxy_user with password 'pytest2'")
with NeonProxy(4432) as proxy:
proxy.start_static()
port = port_distributor.get_port()
pg_port = vanilla_pg.default_options['port']
with NeonProxy(port, pg_port) as proxy:
proxy.start()
yield proxy

View File

@@ -3,7 +3,7 @@ import shutil
import subprocess
from pathlib import Path
from typing import Any, List, Optional
from typing import Any, List
from fixtures.log_helper import log

View File

@@ -0,0 +1,267 @@
import statistics
import threading
import time
import timeit
from typing import Callable
import pytest
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
from fixtures.compare_fixtures import NeonCompare, PgCompare, VanillaCompare
from fixtures.log_helper import log
from fixtures.neon_fixtures import DEFAULT_BRANCH_NAME, NeonEnvBuilder, PgBin
from fixtures.utils import lsn_from_hex
from performance.test_perf_pgbench import (get_durations_matrix, get_scales_matrix)
@pytest.fixture(params=["vanilla", "neon_off", "neon_on"])
# This fixture constructs multiple `PgCompare` interfaces using a builder pattern.
# The builder parameters are encoded in the fixture's param.
# For example, to build a `NeonCompare` interface, the corresponding fixture's param should have
# a format of `neon_{safekeepers_enable_fsync}`.
# Note that, here "_" is used to separate builder parameters.
def pg_compare(request) -> PgCompare:
x = request.param.split("_")
if x[0] == "vanilla":
# `VanillaCompare` interface
fixture = request.getfixturevalue("vanilla_compare")
assert isinstance(fixture, VanillaCompare)
return fixture
else:
assert len(x) == 2, f"request param ({request.param}) should have a format of \
`neon_{{safekeepers_enable_fsync}}`"
# `NeonCompare` interface
neon_env_builder = request.getfixturevalue("neon_env_builder")
assert isinstance(neon_env_builder, NeonEnvBuilder)
zenbenchmark = request.getfixturevalue("zenbenchmark")
assert isinstance(zenbenchmark, NeonBenchmarker)
pg_bin = request.getfixturevalue("pg_bin")
assert isinstance(pg_bin, PgBin)
neon_env_builder.safekeepers_enable_fsync = x[1] == "on"
env = neon_env_builder.init_start()
env.neon_cli.create_branch("empty", ancestor_branch_name=DEFAULT_BRANCH_NAME)
branch_name = request.node.name
return NeonCompare(zenbenchmark, env, pg_bin, branch_name)
def start_heavy_write_workload(env: PgCompare, n_tables: int, scale: int, num_iters: int):
"""Start an intensive write workload across multiple tables.
## Single table workload:
At each step, insert new `new_rows_each_update` rows.
The variable `new_rows_each_update` is equal to `scale * 100_000`.
The number of steps is determined by `num_iters` variable."""
new_rows_each_update = scale * 100_000
def start_single_table_workload(table_id: int):
for _ in range(num_iters):
with env.pg.connect().cursor() as cur:
cur.execute(
f"INSERT INTO t{table_id} SELECT FROM generate_series(1,{new_rows_each_update})"
)
with env.record_duration("run_duration"):
threads = [
threading.Thread(target=start_single_table_workload, args=(i, ))
for i in range(n_tables)
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
@pytest.mark.timeout(1000)
@pytest.mark.parametrize("n_tables", [5])
@pytest.mark.parametrize("scale", get_scales_matrix(5))
@pytest.mark.parametrize("num_iters", [10])
def test_heavy_write_workload(pg_compare: PgCompare, n_tables: int, scale: int, num_iters: int):
env = pg_compare
# Initializes test tables
with env.pg.connect().cursor() as cur:
for i in range(n_tables):
cur.execute(
f"CREATE TABLE t{i}(key serial primary key, t text default 'foooooooooooooooooooooooooooooooooooooooooooooooooooo')"
)
cur.execute(f"INSERT INTO t{i} (key) VALUES (0)")
workload_thread = threading.Thread(target=start_heavy_write_workload,
args=(env, n_tables, scale, num_iters))
workload_thread.start()
record_thread = threading.Thread(target=record_lsn_write_lag,
args=(env, lambda: workload_thread.is_alive()))
record_thread.start()
record_read_latency(env, lambda: workload_thread.is_alive(), "SELECT * from t0 where key = 0")
workload_thread.join()
record_thread.join()
def start_pgbench_simple_update_workload(env: PgCompare, duration: int):
with env.record_duration("run_duration"):
env.pg_bin.run_capture([
'pgbench',
'-j10',
'-c10',
'-N',
f'-T{duration}',
'-Mprepared',
env.pg.connstr(options="-csynchronous_commit=off")
])
env.flush()
@pytest.mark.timeout(1000)
@pytest.mark.parametrize("scale", get_scales_matrix(100))
@pytest.mark.parametrize("duration", get_durations_matrix())
def test_pgbench_simple_update_workload(pg_compare: PgCompare, scale: int, duration: int):
env = pg_compare
# initialize pgbench tables
env.pg_bin.run_capture(['pgbench', f'-s{scale}', '-i', env.pg.connstr()])
env.flush()
workload_thread = threading.Thread(target=start_pgbench_simple_update_workload,
args=(env, duration))
workload_thread.start()
record_thread = threading.Thread(target=record_lsn_write_lag,
args=(env, lambda: workload_thread.is_alive()))
record_thread.start()
record_read_latency(env,
lambda: workload_thread.is_alive(),
"SELECT * from pgbench_accounts where aid = 1")
workload_thread.join()
record_thread.join()
def start_pgbench_intensive_initialization(env: PgCompare, scale: int):
with env.record_duration("run_duration"):
# Needs to increase the statement timeout (default: 120s) because the
# initialization step can be slow with a large scale.
env.pg_bin.run_capture([
'pgbench',
f'-s{scale}',
'-i',
'-Idtg',
env.pg.connstr(options='-cstatement_timeout=300s')
])
@pytest.mark.timeout(1000)
@pytest.mark.parametrize("scale", get_scales_matrix(1000))
def test_pgbench_intensive_init_workload(pg_compare: PgCompare, scale: int):
env = pg_compare
with env.pg.connect().cursor() as cur:
cur.execute("CREATE TABLE foo as select generate_series(1,100000)")
workload_thread = threading.Thread(target=start_pgbench_intensive_initialization,
args=(env, scale))
workload_thread.start()
record_thread = threading.Thread(target=record_lsn_write_lag,
args=(env, lambda: workload_thread.is_alive()))
record_thread.start()
record_read_latency(env, lambda: workload_thread.is_alive(), "SELECT count(*) from foo")
workload_thread.join()
record_thread.join()
def record_lsn_write_lag(env: PgCompare, run_cond: Callable[[], bool], pool_interval: float = 1.0):
if not isinstance(env, NeonCompare):
return
lsn_write_lags = []
last_received_lsn = 0
last_pg_flush_lsn = 0
with env.pg.connect().cursor() as cur:
cur.execute("CREATE EXTENSION neon")
while run_cond():
cur.execute('''
select pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn),
pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn)),
pg_current_wal_flush_lsn(),
received_lsn
from backpressure_lsns();
''')
res = cur.fetchone()
lsn_write_lags.append(res[0])
curr_received_lsn = lsn_from_hex(res[3])
lsn_process_speed = (curr_received_lsn - last_received_lsn) / (1024**2)
last_received_lsn = curr_received_lsn
curr_pg_flush_lsn = lsn_from_hex(res[2])
lsn_produce_speed = (curr_pg_flush_lsn - last_pg_flush_lsn) / (1024**2)
last_pg_flush_lsn = curr_pg_flush_lsn
log.info(
f"received_lsn_lag={res[1]}, pg_flush_lsn={res[2]}, received_lsn={res[3]}, lsn_process_speed={lsn_process_speed:.2f}MB/s, lsn_produce_speed={lsn_produce_speed:.2f}MB/s"
)
time.sleep(pool_interval)
env.zenbenchmark.record("lsn_write_lag_max",
float(max(lsn_write_lags) / (1024**2)),
"MB",
MetricReport.LOWER_IS_BETTER)
env.zenbenchmark.record("lsn_write_lag_avg",
float(statistics.mean(lsn_write_lags) / (1024**2)),
"MB",
MetricReport.LOWER_IS_BETTER)
env.zenbenchmark.record("lsn_write_lag_stdev",
float(statistics.stdev(lsn_write_lags) / (1024**2)),
"MB",
MetricReport.LOWER_IS_BETTER)
def record_read_latency(env: PgCompare,
run_cond: Callable[[], bool],
read_query: str,
read_interval: float = 1.0):
read_latencies = []
with env.pg.connect().cursor() as cur:
while run_cond():
try:
t1 = timeit.default_timer()
cur.execute(read_query)
t2 = timeit.default_timer()
log.info(
f"Executed read query {read_query}, got {cur.fetchall()}, read time {t2-t1:.2f}s"
)
read_latencies.append(t2 - t1)
except Exception as err:
log.error(f"Got error when executing the read query: {err}")
time.sleep(read_interval)
env.zenbenchmark.record("read_latency_max",
max(read_latencies),
's',
MetricReport.LOWER_IS_BETTER)
env.zenbenchmark.record("read_latency_avg",
statistics.mean(read_latencies),
's',
MetricReport.LOWER_IS_BETTER)
env.zenbenchmark.record("read_latency_stdev",
statistics.stdev(read_latencies),
's',
MetricReport.LOWER_IS_BETTER)