mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-02 21:10:38 +00:00
Compare commits
50 Commits
tmp-repro
...
arthur/sim
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6a00ad3aab | ||
|
|
61e6b24cb2 | ||
|
|
44c7d96ed0 | ||
|
|
10ad3ae4eb | ||
|
|
eb2886b401 | ||
|
|
0dc262a84a | ||
|
|
d801ba7248 | ||
|
|
1effb586ba | ||
|
|
2fd351fd63 | ||
|
|
13e94bf687 | ||
|
|
41b9750e81 | ||
|
|
f8729f046d | ||
|
|
420d3bc18f | ||
|
|
33f7877d1b | ||
|
|
7de94c959a | ||
|
|
731ed3bb64 | ||
|
|
413ce2cfe8 | ||
|
|
7f36028fab | ||
|
|
cb6a8d3fe3 | ||
|
|
095747afc0 | ||
|
|
89bd7ab8a3 | ||
|
|
5034a8cca0 | ||
|
|
55e40d090e | ||
|
|
d87e822169 | ||
|
|
296a0cbac2 | ||
|
|
aed14f52d5 | ||
|
|
909d7fadb8 | ||
|
|
3840d6b18b | ||
|
|
65f92232e6 | ||
|
|
0d4f987fc8 | ||
|
|
aa0763d49d | ||
|
|
7b5123edda | ||
|
|
b6a80bc269 | ||
|
|
ac82b34c64 | ||
|
|
a77fc2c5ff | ||
|
|
9ccbec0e14 | ||
|
|
b55005d2c4 | ||
|
|
6436432a77 | ||
|
|
1b8918e665 | ||
|
|
87c9edac7c | ||
|
|
5e0550a620 | ||
|
|
06f493f525 | ||
|
|
f6b540ebfe | ||
|
|
83f87af02b | ||
|
|
79823c38cd | ||
|
|
072fb3d7e9 | ||
|
|
f2fb9f6be9 | ||
|
|
dd4c8fb568 | ||
|
|
9116c01614 | ||
|
|
17cd96e022 |
@@ -14,3 +14,6 @@ opt-level = 1
|
||||
|
||||
[alias]
|
||||
build_testing = ["build", "--features", "testing"]
|
||||
|
||||
[build]
|
||||
rustflags = ["-C", "default-linker-libraries"]
|
||||
|
||||
@@ -19,7 +19,7 @@ inputs:
|
||||
run_in_parallel:
|
||||
description: 'Whether to run tests in parallel'
|
||||
required: false
|
||||
default: 'false'
|
||||
default: 'true'
|
||||
save_perf_report:
|
||||
description: 'Whether to upload the performance report, if true PERF_TEST_RESULT_CONNSTR env variable should be set'
|
||||
required: false
|
||||
@@ -171,7 +171,7 @@ runs:
|
||||
--junitxml=$TEST_OUTPUT/junit.xml \
|
||||
--alluredir=$TEST_OUTPUT/allure/results \
|
||||
--tb=short \
|
||||
--verbose -k "test_forward or test_create_snapsh" -x \
|
||||
--verbose \
|
||||
-rA $TEST_SELECTION $EXTRA_PARAMS
|
||||
|
||||
if [[ "${{ inputs.save_perf_report }}" == "true" ]]; then
|
||||
|
||||
22
.github/ansible/deploy.yaml
vendored
22
.github/ansible/deploy.yaml
vendored
@@ -91,15 +91,6 @@
|
||||
tags:
|
||||
- pageserver
|
||||
|
||||
# used in `pageserver.service` template
|
||||
- name: learn current availability_zone
|
||||
shell:
|
||||
cmd: "curl -s http://169.254.169.254/latest/meta-data/placement/availability-zone"
|
||||
register: ec2_availability_zone
|
||||
|
||||
- set_fact:
|
||||
ec2_availability_zone={{ ec2_availability_zone.stdout }}
|
||||
|
||||
- name: upload systemd service definition
|
||||
ansible.builtin.template:
|
||||
src: systemd/pageserver.service
|
||||
@@ -127,7 +118,7 @@
|
||||
cmd: |
|
||||
INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
|
||||
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers/$INSTANCE_ID | jq '.version = {{ current_version }}' > /tmp/new_version
|
||||
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/pageservers
|
||||
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/pageservers
|
||||
tags:
|
||||
- pageserver
|
||||
|
||||
@@ -162,15 +153,6 @@
|
||||
tags:
|
||||
- safekeeper
|
||||
|
||||
# used in `safekeeper.service` template
|
||||
- name: learn current availability_zone
|
||||
shell:
|
||||
cmd: "curl -s http://169.254.169.254/latest/meta-data/placement/availability-zone"
|
||||
register: ec2_availability_zone
|
||||
|
||||
- set_fact:
|
||||
ec2_availability_zone={{ ec2_availability_zone.stdout }}
|
||||
|
||||
# in the future safekeepers should discover pageservers byself
|
||||
# but currently use first pageserver that was discovered
|
||||
- name: set first pageserver var for safekeepers
|
||||
@@ -206,6 +188,6 @@
|
||||
cmd: |
|
||||
INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
|
||||
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers/$INSTANCE_ID | jq '.version = {{ current_version }}' > /tmp/new_version
|
||||
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/safekeepers
|
||||
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/safekeepers
|
||||
tags:
|
||||
- safekeeper
|
||||
|
||||
2
.github/ansible/prod.eu-central-1.hosts.yaml
vendored
2
.github/ansible/prod.eu-central-1.hosts.yaml
vendored
@@ -27,8 +27,6 @@ storage:
|
||||
ansible_host: i-0cd8d316ecbb715be
|
||||
pageserver-1.eu-central-1.aws.neon.tech:
|
||||
ansible_host: i-090044ed3d383fef0
|
||||
pageserver-2.eu-central-1.aws.neon.tech:
|
||||
ansible_host: i-033584edf3f4b6742
|
||||
|
||||
safekeepers:
|
||||
hosts:
|
||||
|
||||
2
.github/ansible/scripts/init_pageserver.sh
vendored
2
.github/ansible/scripts/init_pageserver.sh
vendored
@@ -26,7 +26,7 @@ EOF
|
||||
if ! curl -sf -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers/${INSTANCE_ID} -o /dev/null; then
|
||||
|
||||
# not registered, so register it now
|
||||
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" {{ console_mgmt_base_url }}/management/api/v2/pageservers -d@/tmp/payload | jq -r '.id')
|
||||
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers -d@/tmp/payload | jq -r '.id')
|
||||
|
||||
# init pageserver
|
||||
sudo -u pageserver /usr/local/bin/pageserver -c "id=${ID}" -c "pg_distrib_dir='/usr/local'" --init -D /storage/pageserver/data
|
||||
|
||||
2
.github/ansible/scripts/init_safekeeper.sh
vendored
2
.github/ansible/scripts/init_safekeeper.sh
vendored
@@ -25,7 +25,7 @@ EOF
|
||||
if ! curl -sf -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers/${INSTANCE_ID} -o /dev/null; then
|
||||
|
||||
# not registered, so register it now
|
||||
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" {{ console_mgmt_base_url }}/management/api/v2/safekeepers -d@/tmp/payload | jq -r '.id')
|
||||
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers -d@/tmp/payload | jq -r '.id')
|
||||
# init safekeeper
|
||||
sudo -u safekeeper /usr/local/bin/safekeeper --id ${ID} --init -D /storage/safekeeper/data
|
||||
fi
|
||||
|
||||
2
.github/ansible/systemd/pageserver.service
vendored
2
.github/ansible/systemd/pageserver.service
vendored
@@ -6,7 +6,7 @@ After=network.target auditd.service
|
||||
Type=simple
|
||||
User=pageserver
|
||||
Environment=RUST_BACKTRACE=1 NEON_REPO_DIR=/storage/pageserver LD_LIBRARY_PATH=/usr/local/v14/lib SENTRY_DSN={{ SENTRY_URL_PAGESERVER }} SENTRY_ENVIRONMENT={{ sentry_environment }}
|
||||
ExecStart=/usr/local/bin/pageserver -c "pg_distrib_dir='/usr/local'" -c "listen_pg_addr='0.0.0.0:6400'" -c "listen_http_addr='0.0.0.0:9898'" -c "broker_endpoint='{{ broker_endpoint }}'" -c "availability_zone='{{ ec2_availability_zone }}'" -D /storage/pageserver/data
|
||||
ExecStart=/usr/local/bin/pageserver -c "pg_distrib_dir='/usr/local'" -c "listen_pg_addr='0.0.0.0:6400'" -c "listen_http_addr='0.0.0.0:9898'" -c "broker_endpoint='{{ broker_endpoint }}'" -D /storage/pageserver/data
|
||||
ExecReload=/bin/kill -HUP $MAINPID
|
||||
KillMode=mixed
|
||||
KillSignal=SIGINT
|
||||
|
||||
2
.github/ansible/systemd/safekeeper.service
vendored
2
.github/ansible/systemd/safekeeper.service
vendored
@@ -6,7 +6,7 @@ After=network.target auditd.service
|
||||
Type=simple
|
||||
User=safekeeper
|
||||
Environment=RUST_BACKTRACE=1 NEON_REPO_DIR=/storage/safekeeper/data LD_LIBRARY_PATH=/usr/local/v14/lib SENTRY_DSN={{ SENTRY_URL_SAFEKEEPER }} SENTRY_ENVIRONMENT={{ sentry_environment }}
|
||||
ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}{{ hostname_suffix }}:6500 --listen-http {{ inventory_hostname }}{{ hostname_suffix }}:7676 -D /storage/safekeeper/data --broker-endpoint={{ broker_endpoint }} --remote-storage='{bucket_name="{{bucket_name}}", bucket_region="{{bucket_region}}", prefix_in_bucket="{{ safekeeper_s3_prefix }}"}' --availability-zone={{ ec2_availability_zone }}
|
||||
ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}{{ hostname_suffix }}:6500 --listen-http {{ inventory_hostname }}{{ hostname_suffix }}:7676 -D /storage/safekeeper/data --broker-endpoint={{ broker_endpoint }} --remote-storage='{bucket_name="{{bucket_name}}", bucket_region="{{bucket_region}}", prefix_in_bucket="{{ safekeeper_s3_prefix }}"}'
|
||||
ExecReload=/bin/kill -HUP $MAINPID
|
||||
KillMode=mixed
|
||||
KillSignal=SIGINT
|
||||
|
||||
@@ -1,22 +1,6 @@
|
||||
# Helm chart values for neon-proxy-scram.
|
||||
# This is a YAML-formatted file.
|
||||
|
||||
deploymentStrategy:
|
||||
type: RollingUpdate
|
||||
rollingUpdate:
|
||||
maxSurge: 100%
|
||||
maxUnavailable: 50%
|
||||
|
||||
# Delay the kill signal by 7 days (7 * 24 * 60 * 60)
|
||||
# The pod(s) will stay in Terminating, keeps the existing connections
|
||||
# but doesn't receive new ones
|
||||
containerLifecycle:
|
||||
preStop:
|
||||
exec:
|
||||
command: ["/bin/sh", "-c", "sleep 604800"]
|
||||
terminationGracePeriodSeconds: 604800
|
||||
|
||||
|
||||
image:
|
||||
repository: neondatabase/neon
|
||||
|
||||
|
||||
50
.github/workflows/build_and_test.yml
vendored
50
.github/workflows/build_and_test.yml
vendored
@@ -5,7 +5,6 @@ on:
|
||||
branches:
|
||||
- main
|
||||
- release
|
||||
- tmp-repro
|
||||
pull_request:
|
||||
|
||||
defaults:
|
||||
@@ -75,12 +74,15 @@ jobs:
|
||||
- name: Install Python deps
|
||||
run: ./scripts/pysync
|
||||
|
||||
- name: Run ruff to ensure code format
|
||||
run: poetry run ruff .
|
||||
- name: Run isort to ensure code format
|
||||
run: poetry run isort --diff --check .
|
||||
|
||||
- name: Run black to ensure code format
|
||||
run: poetry run black --diff --check .
|
||||
|
||||
- name: Run flake8 to ensure code format
|
||||
run: poetry run flake8 .
|
||||
|
||||
- name: Run mypy to check types
|
||||
run: poetry run mypy .
|
||||
|
||||
@@ -549,48 +551,6 @@ jobs:
|
||||
- name: Cleanup ECR folder
|
||||
run: rm -rf ~/.ecr
|
||||
|
||||
|
||||
neon-image-depot:
|
||||
# For testing this will run side-by-side for a few merges.
|
||||
# This action is not really optimized yet, but gets the job done
|
||||
runs-on: [ self-hosted, gen3, small ]
|
||||
needs: [ tag ]
|
||||
container: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
submodules: true
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '1.19'
|
||||
|
||||
- name: Set up Depot CLI
|
||||
uses: depot/setup-action@v1
|
||||
|
||||
- name: Install Crane & ECR helper
|
||||
run: go install github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cli/docker-credential-ecr-login@69c85dc22db6511932bbf119e1a0cc5c90c69a7f # v0.6.0
|
||||
|
||||
- name: Configure ECR login
|
||||
run: |
|
||||
mkdir /github/home/.docker/
|
||||
echo "{\"credsStore\":\"ecr-login\"}" > /github/home/.docker/config.json
|
||||
|
||||
- name: Build and push
|
||||
uses: depot/build-push-action@v1
|
||||
with:
|
||||
# if no depot.json file is at the root of your repo, you must specify the project id
|
||||
project: nrdv0s4kcs
|
||||
push: true
|
||||
tags: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/neon:depot-${{needs.tag.outputs.build-tag}}
|
||||
|
||||
compute-tools-image:
|
||||
runs-on: [ self-hosted, gen3, large ]
|
||||
needs: [ tag ]
|
||||
|
||||
1
.github/workflows/release.yml
vendored
1
.github/workflows/release.yml
vendored
@@ -31,4 +31,3 @@ jobs:
|
||||
head: releases/${{ steps.date.outputs.date }}
|
||||
base: release
|
||||
title: Release ${{ steps.date.outputs.date }}
|
||||
team_reviewers: release
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -18,3 +18,5 @@ test_output/
|
||||
*.o
|
||||
*.so
|
||||
*.Po
|
||||
|
||||
tmp
|
||||
|
||||
133
Cargo.lock
generated
133
Cargo.lock
generated
@@ -679,6 +679,25 @@ version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
||||
|
||||
[[package]]
|
||||
name = "cbindgen"
|
||||
version = "0.24.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b922faaf31122819ec80c4047cc684c6979a087366c069611e33649bf98e18d"
|
||||
dependencies = [
|
||||
"clap 3.2.23",
|
||||
"heck",
|
||||
"indexmap",
|
||||
"log",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"syn",
|
||||
"tempfile",
|
||||
"toml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.79"
|
||||
@@ -757,9 +776,12 @@ version = "3.2.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "71655c45cb9845d3270c9d6df84ebe72b4dad3c2ba3f7023ad47c144e4e473a5"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"bitflags",
|
||||
"clap_lex 0.2.4",
|
||||
"indexmap",
|
||||
"strsim",
|
||||
"termcolor",
|
||||
"textwrap",
|
||||
]
|
||||
|
||||
@@ -851,7 +873,6 @@ dependencies = [
|
||||
"futures",
|
||||
"hyper",
|
||||
"notify",
|
||||
"num_cpus",
|
||||
"opentelemetry",
|
||||
"postgres",
|
||||
"regex",
|
||||
@@ -914,7 +935,6 @@ dependencies = [
|
||||
"once_cell",
|
||||
"pageserver_api",
|
||||
"postgres",
|
||||
"postgres_backend",
|
||||
"postgres_connection",
|
||||
"regex",
|
||||
"reqwest",
|
||||
@@ -1016,6 +1036,20 @@ version = "1.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6548a0ad5d2549e111e1f6a11a6c2e2d00ce6a3dafe22948d67c2b443f775e52"
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-channel",
|
||||
"crossbeam-deque",
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-queue",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-channel"
|
||||
version = "0.5.6"
|
||||
@@ -1050,6 +1084,16 @@ dependencies = [
|
||||
"scopeguard",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-queue"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.14"
|
||||
@@ -2456,7 +2500,6 @@ dependencies = [
|
||||
"postgres",
|
||||
"postgres-protocol",
|
||||
"postgres-types",
|
||||
"postgres_backend",
|
||||
"postgres_connection",
|
||||
"postgres_ffi",
|
||||
"pq_proto",
|
||||
@@ -2679,28 +2722,6 @@ dependencies = [
|
||||
"postgres-protocol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "postgres_backend"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"futures",
|
||||
"once_cell",
|
||||
"pq_proto",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
"tokio-rustls",
|
||||
"tracing",
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "postgres_connection"
|
||||
version = "0.1.0"
|
||||
@@ -2748,7 +2769,7 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
|
||||
name = "pq_proto"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"anyhow",
|
||||
"bytes",
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
@@ -2923,7 +2944,6 @@ dependencies = [
|
||||
"opentelemetry",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"postgres_backend",
|
||||
"pq_proto",
|
||||
"prometheus",
|
||||
"rand",
|
||||
@@ -3303,6 +3323,15 @@ dependencies = [
|
||||
"base64 0.21.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-split"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "78802c9612b4689d207acff746f38132ca1b12dadb55d471aa5f10fd580f47d3"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.11"
|
||||
@@ -3328,22 +3357,25 @@ dependencies = [
|
||||
"clap 4.1.4",
|
||||
"const_format",
|
||||
"crc32c",
|
||||
"crossbeam",
|
||||
"fs2",
|
||||
"git-version",
|
||||
"hex",
|
||||
"humantime",
|
||||
"hyper",
|
||||
"metrics",
|
||||
"nix",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"postgres",
|
||||
"postgres-protocol",
|
||||
"postgres_backend",
|
||||
"postgres_ffi",
|
||||
"pq_proto",
|
||||
"rand",
|
||||
"regex",
|
||||
"remote_storage",
|
||||
"safekeeper_api",
|
||||
"scopeguard",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
@@ -4524,6 +4556,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"criterion",
|
||||
"futures",
|
||||
"git-version",
|
||||
"heapless",
|
||||
"hex",
|
||||
"hex-literal",
|
||||
@@ -4532,9 +4565,12 @@ dependencies = [
|
||||
"metrics",
|
||||
"nix",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"pq_proto",
|
||||
"rand",
|
||||
"routerify",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"rustls-split",
|
||||
"sentry",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -4545,6 +4581,7 @@ dependencies = [
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
@@ -4600,6 +4637,38 @@ dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "walproposer"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"atty",
|
||||
"bindgen",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"cbindgen",
|
||||
"crc32c",
|
||||
"env_logger",
|
||||
"hex",
|
||||
"hyper",
|
||||
"libc",
|
||||
"log",
|
||||
"memoffset 0.8.0",
|
||||
"once_cell",
|
||||
"postgres",
|
||||
"postgres_ffi",
|
||||
"rand",
|
||||
"regex",
|
||||
"safekeeper",
|
||||
"scopeguard",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"utils",
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
version = "0.3.0"
|
||||
@@ -4848,19 +4917,14 @@ name = "workspace_hack"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"clap 4.1.4",
|
||||
"crossbeam-utils",
|
||||
"digest",
|
||||
"either",
|
||||
"fail",
|
||||
"futures",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-executor",
|
||||
"futures-sink",
|
||||
"futures-util",
|
||||
"hashbrown 0.12.3",
|
||||
"indexmap",
|
||||
@@ -4885,7 +4949,6 @@ dependencies = [
|
||||
"socket2",
|
||||
"syn",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tonic",
|
||||
"tower",
|
||||
|
||||
@@ -64,7 +64,6 @@ md5 = "0.7.0"
|
||||
memoffset = "0.8"
|
||||
nix = "0.26"
|
||||
notify = "5.0.0"
|
||||
num_cpus = "1.15"
|
||||
num-traits = "0.2.15"
|
||||
once_cell = "1.13"
|
||||
opentelemetry = "0.18.0"
|
||||
@@ -134,16 +133,17 @@ heapless = { default-features=false, features=[], git = "https://github.com/japa
|
||||
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
|
||||
metrics = { version = "0.1", path = "./libs/metrics/" }
|
||||
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
|
||||
postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" }
|
||||
postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" }
|
||||
postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" }
|
||||
pq_proto = { version = "0.1", path = "./libs/pq_proto/" }
|
||||
remote_storage = { version = "0.1", path = "./libs/remote_storage/" }
|
||||
safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" }
|
||||
safekeeper = { path = "./safekeeper/" }
|
||||
storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy.
|
||||
tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" }
|
||||
tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
|
||||
utils = { version = "0.1", path = "./libs/utils/" }
|
||||
walproposer = { version = "0.1", path = "./libs/walproposer/" }
|
||||
|
||||
## Common library dependency
|
||||
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
|
||||
|
||||
@@ -39,7 +39,7 @@ ARG CACHEPOT_BUCKET=neon-github-dev
|
||||
|
||||
COPY --from=pg-build /home/nonroot/pg_install/v14/include/postgresql/server pg_install/v14/include/postgresql/server
|
||||
COPY --from=pg-build /home/nonroot/pg_install/v15/include/postgresql/server pg_install/v15/include/postgresql/server
|
||||
COPY --chown=nonroot . .
|
||||
COPY . .
|
||||
|
||||
# Show build caching stats to check if it was used in the end.
|
||||
# Has to be the part of the same RUN since cachepot daemon is killed in the end of this RUN, losing the compilation stats.
|
||||
|
||||
@@ -255,51 +255,6 @@ RUN wget https://github.com/theory/pgtap/archive/refs/tags/v1.2.0.tar.gz -O pgta
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgtap.control
|
||||
|
||||
#########################################################################################
|
||||
#
|
||||
# Layer "prefix-pg-build"
|
||||
# compile Prefix extension
|
||||
#
|
||||
#########################################################################################
|
||||
FROM build-deps AS prefix-pg-build
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.9.tar.gz -O prefix.tar.gz && \
|
||||
mkdir prefix-src && cd prefix-src && tar xvzf ../prefix.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
echo 'trusted = true' >> /usr/local/pgsql/share/extension/prefix.control
|
||||
|
||||
#########################################################################################
|
||||
#
|
||||
# Layer "hll-pg-build"
|
||||
# compile hll extension
|
||||
#
|
||||
#########################################################################################
|
||||
FROM build-deps AS hll-pg-build
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.17.tar.gz -O hll.tar.gz && \
|
||||
mkdir hll-src && cd hll-src && tar xvzf ../hll.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
echo 'trusted = true' >> /usr/local/pgsql/share/extension/hll.control
|
||||
|
||||
#########################################################################################
|
||||
#
|
||||
# Layer "plpgsql-check-pg-build"
|
||||
# compile plpgsql_check extension
|
||||
#
|
||||
#########################################################################################
|
||||
FROM build-deps AS plpgsql-check-pg-build
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.3.2.tar.gz -O plpgsql_check.tar.gz && \
|
||||
mkdir plpgsql_check-src && cd plpgsql_check-src && tar xvzf ../plpgsql_check.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
|
||||
echo 'trusted = true' >> /usr/local/pgsql/share/extension/plpgsql_check.control
|
||||
|
||||
#########################################################################################
|
||||
#
|
||||
# Layer "rust extensions"
|
||||
@@ -323,7 +278,7 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
|
||||
chmod +x rustup-init && \
|
||||
./rustup-init -y --no-modify-path --profile minimal --default-toolchain stable && \
|
||||
rm rustup-init && \
|
||||
cargo install --locked --version 0.7.3 cargo-pgx && \
|
||||
cargo install --git https://github.com/vadim2404/pgx --branch neon_abi_v0.6.1 --locked cargo-pgx && \
|
||||
/bin/bash -c 'cargo pgx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
|
||||
|
||||
USER root
|
||||
@@ -337,11 +292,11 @@ USER root
|
||||
|
||||
FROM rust-extensions-build AS pg-jsonschema-pg-build
|
||||
|
||||
# there is no release tag yet, but we need it due to the superuser fix in the control file
|
||||
RUN wget https://github.com/supabase/pg_jsonschema/archive/caeab60d70b2fd3ae421ec66466a3abbb37b7ee6.tar.gz -O pg_jsonschema.tar.gz && \
|
||||
mkdir pg_jsonschema-src && cd pg_jsonschema-src && tar xvzf ../pg_jsonschema.tar.gz --strip-components=1 -C . && \
|
||||
sed -i 's/pgx = "0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
|
||||
RUN git clone --depth=1 --single-branch --branch neon_abi_v0.1.4 https://github.com/vadim2404/pg_jsonschema/ && \
|
||||
cd pg_jsonschema && \
|
||||
cargo pgx install --release && \
|
||||
# it's needed to enable extension because it uses untrusted C language
|
||||
sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_jsonschema.control && \
|
||||
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_jsonschema.control
|
||||
|
||||
#########################################################################################
|
||||
@@ -353,32 +308,13 @@ RUN wget https://github.com/supabase/pg_jsonschema/archive/caeab60d70b2fd3ae421e
|
||||
|
||||
FROM rust-extensions-build AS pg-graphql-pg-build
|
||||
|
||||
# Currently pgx version bump to >= 0.7.2 causes "call to unsafe function" compliation errors in
|
||||
# pgx-contrib-spiext. There is a branch that removes that dependency, so use it. It is on the
|
||||
# same 1.1 version we've used before.
|
||||
RUN git clone -b remove-pgx-contrib-spiext --single-branch https://github.com/yrashk/pg_graphql && \
|
||||
cd pg_graphql && \
|
||||
sed -i 's/pgx = "~0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
|
||||
sed -i 's/pgx-tests = "~0.7.1"/pgx-tests = "0.7.3"/g' Cargo.toml && \
|
||||
RUN git clone --depth=1 --single-branch --branch neon_abi_v1.1.0 https://github.com/vadim2404/pg_graphql && \
|
||||
cd pg_graphql && \
|
||||
cargo pgx install --release && \
|
||||
# it's needed to enable extension because it uses untrusted C language
|
||||
sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_graphql.control && \
|
||||
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_graphql.control
|
||||
|
||||
#########################################################################################
|
||||
#
|
||||
# Layer "pg-tiktoken-build"
|
||||
# Compile "pg_tiktoken" extension
|
||||
#
|
||||
#########################################################################################
|
||||
|
||||
FROM rust-extensions-build AS pg-tiktoken-pg-build
|
||||
|
||||
RUN git clone --depth=1 --single-branch https://github.com/kelvich/pg_tiktoken && \
|
||||
cd pg_tiktoken && \
|
||||
cargo pgx install --release && \
|
||||
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_tiktoken.control
|
||||
|
||||
#########################################################################################
|
||||
#
|
||||
# Layer "neon-pg-ext-build"
|
||||
@@ -396,23 +332,15 @@ COPY --from=vector-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pgjwt-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pg-jsonschema-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pg-graphql-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pg-tiktoken-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=hypopg-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pg-hashids-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=rum-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pgtap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=prefix-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=hll-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=plpgsql-check-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY pgxn/ pgxn/
|
||||
|
||||
RUN make -j $(getconf _NPROCESSORS_ONLN) \
|
||||
PG_CONFIG=/usr/local/pgsql/bin/pg_config \
|
||||
-C pgxn/neon \
|
||||
-s install && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) \
|
||||
PG_CONFIG=/usr/local/pgsql/bin/pg_config \
|
||||
-C pgxn/neon_utils \
|
||||
-s install
|
||||
|
||||
#########################################################################################
|
||||
@@ -467,7 +395,7 @@ COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-deb
|
||||
|
||||
# Install:
|
||||
# libreadline8 for psql
|
||||
# libicu67, locales for collations (including ICU and plpgsql_check)
|
||||
# libicu67, locales for collations (including ICU)
|
||||
# libossp-uuid16 for extension ossp-uuid
|
||||
# libgeos, libgdal, libsfcgal1, libproj and libprotobuf-c1 for PostGIS
|
||||
# libxml2, libxslt1.1 for xml2
|
||||
|
||||
@@ -1,70 +1,25 @@
|
||||
# Note: this file *mostly* just builds on Dockerfile.compute-node
|
||||
|
||||
ARG SRC_IMAGE
|
||||
ARG VM_INFORMANT_VERSION=v0.1.14
|
||||
# on libcgroup update, make sure to check bootstrap.sh for changes
|
||||
ARG LIBCGROUP_VERSION=v2.0.3
|
||||
ARG VM_INFORMANT_VERSION=v0.1.6
|
||||
|
||||
# Pull VM informant, to copy from later
|
||||
# Pull VM informant and set up inittab
|
||||
FROM neondatabase/vm-informant:$VM_INFORMANT_VERSION as informant
|
||||
|
||||
# Build cgroup-tools
|
||||
#
|
||||
# At time of writing (2023-03-14), debian bullseye has a version of cgroup-tools (technically
|
||||
# libcgroup) that doesn't support cgroup v2 (version 0.41-11). Unfortunately, the vm-informant
|
||||
# requires cgroup v2, so we'll build cgroup-tools ourselves.
|
||||
FROM debian:bullseye-slim as libcgroup-builder
|
||||
ARG LIBCGROUP_VERSION
|
||||
|
||||
RUN set -exu \
|
||||
&& apt update \
|
||||
&& apt install --no-install-recommends -y \
|
||||
git \
|
||||
ca-certificates \
|
||||
automake \
|
||||
cmake \
|
||||
make \
|
||||
gcc \
|
||||
byacc \
|
||||
flex \
|
||||
libtool \
|
||||
libpam0g-dev \
|
||||
&& git clone --depth 1 -b $LIBCGROUP_VERSION https://github.com/libcgroup/libcgroup \
|
||||
&& INSTALL_DIR="/libcgroup-install" \
|
||||
&& mkdir -p "$INSTALL_DIR/bin" "$INSTALL_DIR/include" \
|
||||
&& cd libcgroup \
|
||||
# extracted from bootstrap.sh, with modified flags:
|
||||
&& (test -d m4 || mkdir m4) \
|
||||
&& autoreconf -fi \
|
||||
&& rm -rf autom4te.cache \
|
||||
&& CFLAGS="-O3" ./configure --prefix="$INSTALL_DIR" --sysconfdir=/etc --localstatedir=/var --enable-opaque-hierarchy="name=systemd" \
|
||||
# actually build the thing...
|
||||
&& make install
|
||||
|
||||
# Combine, starting from non-VM compute node image.
|
||||
FROM $SRC_IMAGE as base
|
||||
|
||||
# Temporarily set user back to root so we can run adduser, set inittab
|
||||
USER root
|
||||
RUN adduser vm-informant --disabled-password --no-create-home
|
||||
|
||||
RUN set -e \
|
||||
&& rm -f /etc/inittab \
|
||||
&& touch /etc/inittab
|
||||
|
||||
RUN set -e \
|
||||
&& echo "::sysinit:cgconfigparser -l /etc/cgconfig.conf -s 1664" >> /etc/inittab \
|
||||
&& CONNSTR="dbname=neondb user=cloud_admin sslmode=disable" \
|
||||
&& ARGS="--auto-restart --cgroup=neon-postgres --pgconnstr=\"$CONNSTR\"" \
|
||||
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant $ARGS'" >> /etc/inittab
|
||||
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant --auto-restart'" >> /etc/inittab
|
||||
|
||||
# Combine, starting from non-VM compute node image.
|
||||
FROM $SRC_IMAGE as base
|
||||
|
||||
# Temporarily set user back to root so we can run adduser
|
||||
USER root
|
||||
RUN adduser vm-informant --disabled-password --no-create-home
|
||||
USER postgres
|
||||
|
||||
ADD vm-cgconfig.conf /etc/cgconfig.conf
|
||||
COPY --from=informant /etc/inittab /etc/inittab
|
||||
COPY --from=informant /usr/bin/vm-informant /usr/local/bin/vm-informant
|
||||
|
||||
COPY --from=libcgroup-builder /libcgroup-install/bin/* /usr/bin/
|
||||
COPY --from=libcgroup-builder /libcgroup-install/lib/* /usr/lib/
|
||||
COPY --from=libcgroup-builder /libcgroup-install/sbin/* /usr/sbin/
|
||||
|
||||
ENTRYPOINT ["/usr/sbin/cgexec", "-g", "*:neon-postgres", "/usr/local/bin/compute_ctl"]
|
||||
|
||||
16
Makefile
16
Makefile
@@ -39,6 +39,8 @@ endif
|
||||
# been no changes to the files. Changing the mtime triggers an
|
||||
# unnecessary rebuild of 'postgres_ffi'.
|
||||
PG_CONFIGURE_OPTS += INSTALL='$(ROOT_PROJECT_DIR)/scripts/ninstall.sh -C'
|
||||
PG_CONFIGURE_OPTS += CC=clang
|
||||
PG_CONFIGURE_OPTS += CCX=clang++
|
||||
|
||||
# Choose whether we should be silent or verbose
|
||||
CARGO_BUILD_FLAGS += --$(if $(filter s,$(MAKEFLAGS)),quiet,verbose)
|
||||
@@ -133,11 +135,12 @@ neon-pg-ext-%: postgres-%
|
||||
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
|
||||
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
|
||||
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile install
|
||||
+@echo "Compiling neon_utils $*"
|
||||
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-utils-$*
|
||||
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
|
||||
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
|
||||
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile install
|
||||
|
||||
.PHONY:
|
||||
neon-pg-ext-walproposer:
|
||||
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v15/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
|
||||
-C $(POSTGRES_INSTALL_DIR)/build/neon-v15 \
|
||||
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile install
|
||||
|
||||
.PHONY: neon-pg-ext-clean-%
|
||||
neon-pg-ext-clean-%:
|
||||
@@ -150,9 +153,6 @@ neon-pg-ext-clean-%:
|
||||
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \
|
||||
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
|
||||
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile clean
|
||||
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \
|
||||
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
|
||||
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile clean
|
||||
|
||||
.PHONY: neon-pg-ext
|
||||
neon-pg-ext: \
|
||||
|
||||
@@ -11,7 +11,6 @@ clap.workspace = true
|
||||
futures.workspace = true
|
||||
hyper = { workspace = true, features = ["full"] }
|
||||
notify.workspace = true
|
||||
num_cpus.workspace = true
|
||||
opentelemetry.workspace = true
|
||||
postgres.workspace = true
|
||||
regex.workspace = true
|
||||
|
||||
@@ -25,7 +25,6 @@ use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use postgres::{Client, NoTls};
|
||||
use serde::{Serialize, Serializer};
|
||||
use tokio_postgres;
|
||||
use tracing::{info, instrument, warn};
|
||||
|
||||
use crate::checker::create_writability_check_data;
|
||||
@@ -285,7 +284,6 @@ impl ComputeNode {
|
||||
handle_role_deletions(self, &mut client)?;
|
||||
handle_grants(self, &mut client)?;
|
||||
create_writability_check_data(&mut client)?;
|
||||
handle_extensions(&self.spec, &mut client)?;
|
||||
|
||||
// 'Close' connection
|
||||
drop(client);
|
||||
@@ -402,43 +400,4 @@ impl ComputeNode {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Select `pg_stat_statements` data and return it as a stringified JSON
|
||||
pub async fn collect_insights(&self) -> String {
|
||||
let mut result_rows: Vec<String> = Vec::new();
|
||||
let connect_result = tokio_postgres::connect(self.connstr.as_str(), NoTls).await;
|
||||
let (client, connection) = connect_result.unwrap();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
let result = client
|
||||
.simple_query(
|
||||
"SELECT
|
||||
row_to_json(pg_stat_statements)
|
||||
FROM
|
||||
pg_stat_statements
|
||||
WHERE
|
||||
userid != 'cloud_admin'::regrole::oid
|
||||
ORDER BY
|
||||
(mean_exec_time + mean_plan_time) DESC
|
||||
LIMIT 100",
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Ok(raw_rows) = result {
|
||||
for message in raw_rows.iter() {
|
||||
if let postgres::SimpleQueryMessage::Row(row) = message {
|
||||
if let Some(json) = row.get(0) {
|
||||
result_rows.push(json.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
format!("{{\"pg_stat_statements\": [{}]}}", result_rows.join(","))
|
||||
} else {
|
||||
"{{\"pg_stat_statements\": []}}".to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ use crate::compute::ComputeNode;
|
||||
use anyhow::Result;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Method, Request, Response, Server, StatusCode};
|
||||
use num_cpus;
|
||||
use serde_json;
|
||||
use tracing::{error, info};
|
||||
use tracing_utils::http::OtelName;
|
||||
@@ -34,13 +33,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
Response::new(Body::from(serde_json::to_string(&compute.metrics).unwrap()))
|
||||
}
|
||||
|
||||
// Collect Postgres current usage insights
|
||||
(&Method::GET, "/insights") => {
|
||||
info!("serving /insights GET request");
|
||||
let insights = compute.collect_insights().await;
|
||||
Response::new(Body::from(insights))
|
||||
}
|
||||
|
||||
(&Method::POST, "/check_writability") => {
|
||||
info!("serving /check_writability POST request");
|
||||
let res = crate::checker::check_writability(compute).await;
|
||||
@@ -50,17 +42,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
}
|
||||
}
|
||||
|
||||
(&Method::GET, "/info") => {
|
||||
let num_cpus = num_cpus::get_physical();
|
||||
info!("serving /info GET request. num_cpus: {}", num_cpus);
|
||||
Response::new(Body::from(
|
||||
serde_json::json!({
|
||||
"num_cpus": num_cpus,
|
||||
})
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
// Return the `404 Not Found` for any other routes.
|
||||
_ => {
|
||||
let mut not_found = Response::new(Body::from("404 Not Found"));
|
||||
|
||||
@@ -10,12 +10,12 @@ paths:
|
||||
/status:
|
||||
get:
|
||||
tags:
|
||||
- Info
|
||||
- "info"
|
||||
summary: Get compute node internal status
|
||||
description: ""
|
||||
operationId: getComputeStatus
|
||||
responses:
|
||||
200:
|
||||
"200":
|
||||
description: ComputeState
|
||||
content:
|
||||
application/json:
|
||||
@@ -25,58 +25,27 @@ paths:
|
||||
/metrics.json:
|
||||
get:
|
||||
tags:
|
||||
- Info
|
||||
- "info"
|
||||
summary: Get compute node startup metrics in JSON format
|
||||
description: ""
|
||||
operationId: getComputeMetricsJSON
|
||||
responses:
|
||||
200:
|
||||
"200":
|
||||
description: ComputeMetrics
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ComputeMetrics"
|
||||
|
||||
/insights:
|
||||
get:
|
||||
tags:
|
||||
- Info
|
||||
summary: Get current compute insights in JSON format
|
||||
description: |
|
||||
Note, that this doesn't include any historical data
|
||||
operationId: getComputeInsights
|
||||
responses:
|
||||
200:
|
||||
description: Compute insights
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ComputeInsights"
|
||||
|
||||
/info:
|
||||
get:
|
||||
tags:
|
||||
- "info"
|
||||
summary: Get info about the compute Pod/VM
|
||||
description: ""
|
||||
operationId: getInfo
|
||||
responses:
|
||||
"200":
|
||||
description: Info
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Info"
|
||||
|
||||
/check_writability:
|
||||
post:
|
||||
tags:
|
||||
- Check
|
||||
- "check"
|
||||
summary: Check that we can write new data on this compute
|
||||
description: ""
|
||||
operationId: checkComputeWritability
|
||||
responses:
|
||||
200:
|
||||
"200":
|
||||
description: Check result
|
||||
content:
|
||||
text/plain:
|
||||
@@ -111,15 +80,6 @@ components:
|
||||
total_startup_ms:
|
||||
type: integer
|
||||
|
||||
Info:
|
||||
type: object
|
||||
description: Information about VM/Pod
|
||||
required:
|
||||
- num_cpus
|
||||
properties:
|
||||
num_cpus:
|
||||
type: integer
|
||||
|
||||
ComputeState:
|
||||
type: object
|
||||
required:
|
||||
@@ -136,15 +96,6 @@ components:
|
||||
type: string
|
||||
description: Text of the error during compute startup, if any
|
||||
|
||||
ComputeInsights:
|
||||
type: object
|
||||
properties:
|
||||
pg_stat_statements:
|
||||
description: Contains raw output from pg_stat_statements in JSON format
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
|
||||
ComputeStatus:
|
||||
type: string
|
||||
enum:
|
||||
|
||||
@@ -47,23 +47,12 @@ pub struct GenericOption {
|
||||
/// declare a `trait` on it.
|
||||
pub type GenericOptions = Option<Vec<GenericOption>>;
|
||||
|
||||
/// Escape a string for including it in a SQL literal
|
||||
fn escape_literal(s: &str) -> String {
|
||||
s.replace('\'', "''").replace('\\', "\\\\")
|
||||
}
|
||||
|
||||
/// Escape a string so that it can be used in postgresql.conf.
|
||||
/// Same as escape_literal, currently.
|
||||
fn escape_conf_value(s: &str) -> String {
|
||||
s.replace('\'', "''").replace('\\', "\\\\")
|
||||
}
|
||||
|
||||
impl GenericOption {
|
||||
/// Represent `GenericOption` as SQL statement parameter.
|
||||
pub fn to_pg_option(&self) -> String {
|
||||
if let Some(val) = &self.value {
|
||||
match self.vartype.as_ref() {
|
||||
"string" => format!("{} '{}'", self.name, escape_literal(val)),
|
||||
"string" => format!("{} '{}'", self.name, val),
|
||||
_ => format!("{} {}", self.name, val),
|
||||
}
|
||||
} else {
|
||||
@@ -74,8 +63,6 @@ impl GenericOption {
|
||||
/// Represent `GenericOption` as configuration option.
|
||||
pub fn to_pg_setting(&self) -> String {
|
||||
if let Some(val) = &self.value {
|
||||
// TODO: check in the console DB that we don't have these settings
|
||||
// set for any non-deleted project and drop this override.
|
||||
let name = match self.name.as_str() {
|
||||
"safekeepers" => "neon.safekeepers",
|
||||
"wal_acceptor_reconnect" => "neon.safekeeper_reconnect_timeout",
|
||||
@@ -84,7 +71,7 @@ impl GenericOption {
|
||||
};
|
||||
|
||||
match self.vartype.as_ref() {
|
||||
"string" => format!("{} = '{}'", name, escape_conf_value(val)),
|
||||
"string" => format!("{} = '{}'", name, val),
|
||||
_ => format!("{} = {}", name, val),
|
||||
}
|
||||
} else {
|
||||
@@ -120,7 +107,6 @@ impl PgOptionsSerialize for GenericOptions {
|
||||
.map(|op| op.to_pg_setting())
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n")
|
||||
+ "\n" // newline after last setting
|
||||
} else {
|
||||
"".to_string()
|
||||
}
|
||||
|
||||
@@ -515,18 +515,3 @@ pub fn handle_grants(node: &ComputeNode, client: &mut Client) -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create required system extensions
|
||||
#[instrument(skip_all)]
|
||||
pub fn handle_extensions(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
|
||||
if libs.contains("pg_stat_statements") {
|
||||
// Create extension only if this compute really needs it
|
||||
let query = "CREATE EXTENSION IF NOT EXISTS pg_stat_statements";
|
||||
info!("creating system extensions with query: {}", query);
|
||||
client.simple_query(query)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -178,11 +178,6 @@
|
||||
"name": "neon.pageserver_connstring",
|
||||
"value": "host=127.0.0.1 port=6400",
|
||||
"vartype": "string"
|
||||
},
|
||||
{
|
||||
"name": "test.escaping",
|
||||
"value": "here's a backslash \\ and a quote ' and a double-quote \" hooray",
|
||||
"vartype": "string"
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@@ -28,30 +28,7 @@ mod pg_helpers_tests {
|
||||
|
||||
assert_eq!(
|
||||
spec.cluster.settings.as_pg_settings(),
|
||||
r#"fsync = off
|
||||
wal_level = replica
|
||||
hot_standby = on
|
||||
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
|
||||
wal_log_hints = on
|
||||
log_connections = on
|
||||
shared_buffers = 32768
|
||||
port = 55432
|
||||
max_connections = 100
|
||||
max_wal_senders = 10
|
||||
listen_addresses = '0.0.0.0'
|
||||
wal_sender_timeout = 0
|
||||
password_encryption = md5
|
||||
maintenance_work_mem = 65536
|
||||
max_parallel_workers = 8
|
||||
max_worker_processes = 8
|
||||
neon.tenant_id = 'b0554b632bd4d547a63b86c3630317e8'
|
||||
max_replication_slots = 10
|
||||
neon.timeline_id = '2414a61ffc94e428f14b5758fe308e13'
|
||||
shared_preload_libraries = 'neon'
|
||||
synchronous_standby_names = 'walproposer'
|
||||
neon.pageserver_connstring = 'host=127.0.0.1 port=6400'
|
||||
test.escaping = 'here''s a backslash \\ and a quote '' and a double-quote " hooray'
|
||||
"#
|
||||
"fsync = off\nwal_level = replica\nhot_standby = on\nneon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'\nwal_log_hints = on\nlog_connections = on\nshared_buffers = 32768\nport = 55432\nmax_connections = 100\nmax_wal_senders = 10\nlisten_addresses = '0.0.0.0'\nwal_sender_timeout = 0\npassword_encryption = md5\nmaintenance_work_mem = 65536\nmax_parallel_workers = 8\nmax_worker_processes = 8\nneon.tenant_id = 'b0554b632bd4d547a63b86c3630317e8'\nmax_replication_slots = 10\nneon.timeline_id = '2414a61ffc94e428f14b5758fe308e13'\nshared_preload_libraries = 'neon'\nsynchronous_standby_names = 'walproposer'\nneon.pageserver_connstring = 'host=127.0.0.1 port=6400'"
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ url.workspace = true
|
||||
# Note: Do not directly depend on pageserver or safekeeper; use pageserver_api or safekeeper_api
|
||||
# instead, so that recompile times are better.
|
||||
pageserver_api.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
safekeeper_api.workspace = true
|
||||
postgres_connection.workspace = true
|
||||
storage_broker.workspace = true
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
[pageserver]
|
||||
listen_pg_addr = '127.0.0.1:64000'
|
||||
listen_http_addr = '127.0.0.1:9898'
|
||||
pg_auth_type = 'Trust'
|
||||
http_auth_type = 'Trust'
|
||||
auth_type = 'Trust'
|
||||
|
||||
[[safekeepers]]
|
||||
id = 1
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
[pageserver]
|
||||
listen_pg_addr = '127.0.0.1:64000'
|
||||
listen_http_addr = '127.0.0.1:9898'
|
||||
pg_auth_type = 'Trust'
|
||||
http_auth_type = 'Trust'
|
||||
auth_type = 'Trust'
|
||||
|
||||
[[safekeepers]]
|
||||
id = 1
|
||||
|
||||
@@ -17,7 +17,6 @@ use pageserver_api::{
|
||||
DEFAULT_HTTP_LISTEN_ADDR as DEFAULT_PAGESERVER_HTTP_ADDR,
|
||||
DEFAULT_PG_LISTEN_ADDR as DEFAULT_PAGESERVER_PG_ADDR,
|
||||
};
|
||||
use postgres_backend::AuthType;
|
||||
use safekeeper_api::{
|
||||
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
|
||||
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
|
||||
@@ -31,6 +30,7 @@ use utils::{
|
||||
auth::{Claims, Scope},
|
||||
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::AuthType,
|
||||
project_git_version,
|
||||
};
|
||||
|
||||
@@ -53,15 +53,14 @@ listen_addr = '{DEFAULT_BROKER_ADDR}'
|
||||
id = {DEFAULT_PAGESERVER_ID}
|
||||
listen_pg_addr = '{DEFAULT_PAGESERVER_PG_ADDR}'
|
||||
listen_http_addr = '{DEFAULT_PAGESERVER_HTTP_ADDR}'
|
||||
pg_auth_type = '{trust_auth}'
|
||||
http_auth_type = '{trust_auth}'
|
||||
auth_type = '{pageserver_auth_type}'
|
||||
|
||||
[[safekeepers]]
|
||||
id = {DEFAULT_SAFEKEEPER_ID}
|
||||
pg_port = {DEFAULT_SAFEKEEPER_PG_PORT}
|
||||
http_port = {DEFAULT_SAFEKEEPER_HTTP_PORT}
|
||||
"#,
|
||||
trust_auth = AuthType::Trust,
|
||||
pageserver_auth_type = AuthType::Trust,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -628,7 +627,7 @@ fn handle_pg(pg_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
|
||||
|
||||
let node = cplane.nodes.get(&(tenant_id, node_name.to_string()));
|
||||
|
||||
let auth_token = if matches!(env.pageserver.pg_auth_type, AuthType::NeonJWT) {
|
||||
let auth_token = if matches!(env.pageserver.auth_type, AuthType::NeonJWT) {
|
||||
let claims = Claims::new(Some(tenant_id), Scope::Tenant);
|
||||
|
||||
Some(env.generate_auth_token(&claims)?)
|
||||
|
||||
@@ -11,10 +11,10 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use postgres_backend::AuthType;
|
||||
use utils::{
|
||||
id::{TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::AuthType,
|
||||
};
|
||||
|
||||
use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION};
|
||||
@@ -97,7 +97,7 @@ impl ComputeControlPlane {
|
||||
});
|
||||
|
||||
node.create_pgdata()?;
|
||||
node.setup_pg_conf(self.env.pageserver.pg_auth_type)?;
|
||||
node.setup_pg_conf(self.env.pageserver.auth_type)?;
|
||||
|
||||
self.nodes
|
||||
.insert((tenant_id, node.name.clone()), Arc::clone(&node));
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
use anyhow::{bail, ensure, Context};
|
||||
|
||||
use postgres_backend::AuthType;
|
||||
use reqwest::Url;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
@@ -20,6 +19,7 @@ use std::process::{Command, Stdio};
|
||||
use utils::{
|
||||
auth::{encode_from_key_file, Claims, Scope},
|
||||
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
|
||||
postgres_backend::AuthType,
|
||||
};
|
||||
|
||||
use crate::safekeeper::SafekeeperNode;
|
||||
@@ -110,14 +110,12 @@ impl NeonBroker {
|
||||
pub struct PageServerConf {
|
||||
// node id
|
||||
pub id: NodeId,
|
||||
|
||||
// Pageserver connection settings
|
||||
pub listen_pg_addr: String,
|
||||
pub listen_http_addr: String,
|
||||
|
||||
// auth type used for the PG and HTTP ports
|
||||
pub pg_auth_type: AuthType,
|
||||
pub http_auth_type: AuthType,
|
||||
// used to determine which auth type is used
|
||||
pub auth_type: AuthType,
|
||||
|
||||
// jwt auth token used for communication with pageserver
|
||||
pub auth_token: String,
|
||||
@@ -129,8 +127,7 @@ impl Default for PageServerConf {
|
||||
id: NodeId(0),
|
||||
listen_pg_addr: String::new(),
|
||||
listen_http_addr: String::new(),
|
||||
pg_auth_type: AuthType::Trust,
|
||||
http_auth_type: AuthType::Trust,
|
||||
auth_type: AuthType::Trust,
|
||||
auth_token: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ use anyhow::{bail, Context};
|
||||
use pageserver_api::models::{
|
||||
TenantConfigRequest, TenantCreateRequest, TenantInfo, TimelineCreateRequest, TimelineInfo,
|
||||
};
|
||||
use postgres_backend::AuthType;
|
||||
use postgres_connection::{parse_host_port, PgConnectionConfig};
|
||||
use reqwest::blocking::{Client, RequestBuilder, Response};
|
||||
use reqwest::{IntoUrl, Method};
|
||||
@@ -21,6 +20,7 @@ use utils::{
|
||||
http::error::HttpErrorBody,
|
||||
id::{TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::AuthType,
|
||||
};
|
||||
|
||||
use crate::{background_process, local_env::LocalEnv};
|
||||
@@ -82,7 +82,7 @@ impl PageServerNode {
|
||||
let (host, port) = parse_host_port(&env.pageserver.listen_pg_addr)
|
||||
.expect("Unable to parse listen_pg_addr");
|
||||
let port = port.unwrap_or(5432);
|
||||
let password = if env.pageserver.pg_auth_type == AuthType::NeonJWT {
|
||||
let password = if env.pageserver.auth_type == AuthType::NeonJWT {
|
||||
Some(env.pageserver.auth_token.clone())
|
||||
} else {
|
||||
None
|
||||
@@ -106,32 +106,25 @@ impl PageServerNode {
|
||||
self.env.pg_distrib_dir_raw().display()
|
||||
);
|
||||
|
||||
let http_auth_type_param =
|
||||
format!("http_auth_type='{}'", self.env.pageserver.http_auth_type);
|
||||
let authg_type_param = format!("auth_type='{}'", self.env.pageserver.auth_type);
|
||||
let listen_http_addr_param = format!(
|
||||
"listen_http_addr='{}'",
|
||||
self.env.pageserver.listen_http_addr
|
||||
);
|
||||
|
||||
let pg_auth_type_param = format!("pg_auth_type='{}'", self.env.pageserver.pg_auth_type);
|
||||
let listen_pg_addr_param =
|
||||
format!("listen_pg_addr='{}'", self.env.pageserver.listen_pg_addr);
|
||||
|
||||
let broker_endpoint_param = format!("broker_endpoint='{}'", self.env.broker.client_url());
|
||||
|
||||
let mut overrides = vec![
|
||||
id,
|
||||
pg_distrib_dir_param,
|
||||
http_auth_type_param,
|
||||
pg_auth_type_param,
|
||||
authg_type_param,
|
||||
listen_http_addr_param,
|
||||
listen_pg_addr_param,
|
||||
broker_endpoint_param,
|
||||
];
|
||||
|
||||
if self.env.pageserver.http_auth_type != AuthType::Trust
|
||||
|| self.env.pageserver.pg_auth_type != AuthType::Trust
|
||||
{
|
||||
if self.env.pageserver.auth_type != AuthType::Trust {
|
||||
overrides.push("auth_validation_public_key_path='auth_public_key.pem'".to_owned());
|
||||
}
|
||||
overrides
|
||||
@@ -254,10 +247,7 @@ impl PageServerNode {
|
||||
}
|
||||
|
||||
fn pageserver_env_variables(&self) -> anyhow::Result<Vec<(String, String)>> {
|
||||
// FIXME: why is this tied to pageserver's auth type? Whether or not the safekeeper
|
||||
// needs a token, and how to generate that token, seems independent to whether
|
||||
// the pageserver requires a token in incoming requests.
|
||||
Ok(if self.env.pageserver.http_auth_type != AuthType::Trust {
|
||||
Ok(if self.env.pageserver.auth_type != AuthType::Trust {
|
||||
// Generate a token to connect from the pageserver to a safekeeper
|
||||
let token = self
|
||||
.env
|
||||
@@ -293,7 +283,7 @@ impl PageServerNode {
|
||||
|
||||
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
|
||||
let mut builder = self.http_client.request(method, url);
|
||||
if self.env.pageserver.http_auth_type == AuthType::NeonJWT {
|
||||
if self.env.pageserver.auth_type == AuthType::NeonJWT {
|
||||
builder = builder.bearer_auth(&self.env.pageserver.auth_token)
|
||||
}
|
||||
builder
|
||||
|
||||
@@ -115,10 +115,6 @@ impl SafekeeperNode {
|
||||
let datadir = self.datadir_path();
|
||||
|
||||
let id_string = id.to_string();
|
||||
// TODO: add availability_zone to the config.
|
||||
// Right now we just specify any value here and use it to check metrics in tests.
|
||||
let availability_zone = format!("sk-{}", id_string);
|
||||
|
||||
let mut args = vec![
|
||||
"-D",
|
||||
datadir.to_str().with_context(|| {
|
||||
@@ -130,8 +126,6 @@ impl SafekeeperNode {
|
||||
&listen_pg,
|
||||
"--listen-http",
|
||||
&listen_http,
|
||||
"--availability-zone",
|
||||
&availability_zone,
|
||||
];
|
||||
if !self.conf.sync {
|
||||
args.push("--no-sync");
|
||||
|
||||
@@ -29,41 +29,6 @@ These components should not have access to the private key and may only get toke
|
||||
The key pair is generated once for an installation of compute/pageserver/safekeeper, e.g. by `neon_local init`.
|
||||
There is currently no way to rotate the key without bringing down all components.
|
||||
|
||||
### Token format
|
||||
|
||||
The JWT tokens in Neon use RSA as the algorithm. Example:
|
||||
|
||||
Header:
|
||||
|
||||
```
|
||||
{
|
||||
"alg": "RS512", # RS256, RS384, or RS512
|
||||
"typ": "JWT"
|
||||
}
|
||||
```
|
||||
|
||||
Payload:
|
||||
|
||||
```
|
||||
{
|
||||
"scope": "tenant", # "tenant", "pageserverapi", or "safekeeperdata"
|
||||
"tenant_id": "5204921ff44f09de8094a1390a6a50f6",
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
Meanings of scope:
|
||||
|
||||
"tenant": Provides access to all data for a specific tenant
|
||||
|
||||
"pageserverapi": Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
|
||||
Should only be used e.g. for status check/tenant creation/list.
|
||||
|
||||
"safekeeperdata": Provides blanket access to all data on the safekeeper plus safekeeper-wide APIs.
|
||||
Should only be used e.g. for status check.
|
||||
Currently also used for connection from any pageserver to any safekeeper.
|
||||
|
||||
|
||||
### CLI
|
||||
CLI generates a key pair during call to `neon_local init` with the following commands:
|
||||
|
||||
@@ -137,12 +102,10 @@ Each compute should present a token valid for the timeline's tenant.
|
||||
Pageserver also has HTTP API: some parts are per-tenant,
|
||||
some parts are server-wide, these are different scopes.
|
||||
|
||||
Authentication can be enabled separately for the HTTP mgmt API, and
|
||||
for the libpq connections from compute. The `http_auth_type` and
|
||||
`pg_auth_type` configuration variables in Pageserver's config may
|
||||
have one of these values:
|
||||
The `auth_type` configuration variable in Pageserver's config may have
|
||||
either of three values:
|
||||
|
||||
* `Trust` removes all authentication.
|
||||
* `Trust` removes all authentication. The outdated `MD5` value does likewise
|
||||
* `NeonJWT` enables JWT validation.
|
||||
Tokens are validated using the public key which lies in a PEM file
|
||||
specified in the `auth_validation_public_key_path` config.
|
||||
|
||||
@@ -129,12 +129,13 @@ Run `poetry shell` to activate the virtual environment.
|
||||
Alternatively, use `poetry run` to run a single command in the venv, e.g. `poetry run pytest`.
|
||||
|
||||
### Obligatory checks
|
||||
We force code formatting via `black`, `ruff`, and type hints via `mypy`.
|
||||
We force code formatting via `black`, `isort` and type hints via `mypy`.
|
||||
Run the following commands in the repository's root (next to `pyproject.toml`):
|
||||
|
||||
```bash
|
||||
poetry run isort . # Imports are reformatted
|
||||
poetry run black . # All code is reformatted
|
||||
poetry run ruff . # Python linter
|
||||
poetry run flake8 . # Python linter
|
||||
poetry run mypy . # Ensure there are no typing errors
|
||||
```
|
||||
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
[package]
|
||||
name = "postgres_backend"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-trait.workspace = true
|
||||
anyhow.workspace = true
|
||||
bytes.workspace = true
|
||||
futures.workspace = true
|
||||
rustls.workspace = true
|
||||
serde.workspace = true
|
||||
thiserror.workspace = true
|
||||
tokio.workspace = true
|
||||
tokio-rustls.workspace = true
|
||||
tracing.workspace = true
|
||||
|
||||
pq_proto.workspace = true
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
once_cell.workspace = true
|
||||
rustls-pemfile.workspace = true
|
||||
tokio-postgres.workspace = true
|
||||
tokio-postgres-rustls.workspace = true
|
||||
@@ -1,931 +0,0 @@
|
||||
//! Server-side asynchronous Postgres connection, as limited as we need.
|
||||
//! To use, create PostgresBackend and run() it, passing the Handler
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use futures::pin_mut;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::ErrorKind;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{ready, Poll};
|
||||
use std::{fmt, io};
|
||||
use std::{future::Future, str::FromStr};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tracing::{debug, error, info, trace};
|
||||
|
||||
use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
|
||||
use pq_proto::{
|
||||
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
|
||||
SQLSTATE_SUCCESSFUL_COMPLETION,
|
||||
};
|
||||
|
||||
/// An error, occurred during query processing:
|
||||
/// either during the connection ([`ConnectionError`]) or before/after it.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum QueryError {
|
||||
/// The connection was lost while processing the query.
|
||||
#[error(transparent)]
|
||||
Disconnected(#[from] ConnectionError),
|
||||
/// Some other error
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl From<io::Error> for QueryError {
|
||||
fn from(e: io::Error) -> Self {
|
||||
Self::Disconnected(ConnectionError::Io(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryError {
|
||||
pub fn pg_error_code(&self) -> &'static [u8; 5] {
|
||||
match self {
|
||||
Self::Disconnected(_) => b"08006", // connection failure
|
||||
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_expected_io_error(e: &io::Error) -> bool {
|
||||
use io::ErrorKind::*;
|
||||
matches!(
|
||||
e.kind(),
|
||||
ConnectionRefused | ConnectionAborted | ConnectionReset
|
||||
)
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Handler<IO> {
|
||||
/// Handle single query.
|
||||
/// postgres_backend will issue ReadyForQuery after calling this (this
|
||||
/// might be not what we want after CopyData streaming, but currently we don't
|
||||
/// care). It will also flush out the output buffer.
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError>;
|
||||
|
||||
/// Called on startup packet receival, allows to process params.
|
||||
///
|
||||
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
|
||||
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
|
||||
/// to override whole init logic in implementations.
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend<IO>,
|
||||
_sm: &FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check auth jwt
|
||||
fn check_auth_jwt(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend<IO>,
|
||||
_jwt_response: &[u8],
|
||||
) -> Result<(), QueryError> {
|
||||
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
|
||||
}
|
||||
}
|
||||
|
||||
/// PostgresBackend protocol state.
|
||||
/// XXX: The order of the constructors matters.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
|
||||
pub enum ProtoState {
|
||||
/// Nothing happened yet.
|
||||
Initialization,
|
||||
/// Encryption handshake is done; waiting for encrypted Startup message.
|
||||
Encrypted,
|
||||
/// Waiting for password (auth token).
|
||||
Authentication,
|
||||
/// Performed handshake and auth, ReadyForQuery is issued.
|
||||
Established,
|
||||
Closed,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ProcessMsgResult {
|
||||
Continue,
|
||||
Break,
|
||||
}
|
||||
|
||||
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
|
||||
pub enum MaybeTlsStream<IO> {
|
||||
Unencrypted(IO),
|
||||
Tls(Box<tokio_rustls::server::TlsStream<IO>>),
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum AuthType {
|
||||
Trust,
|
||||
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
|
||||
NeonJWT,
|
||||
}
|
||||
|
||||
impl FromStr for AuthType {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"Trust" => Ok(Self::Trust),
|
||||
"NeonJWT" => Ok(Self::NeonJWT),
|
||||
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
AuthType::Trust => "Trust",
|
||||
AuthType::NeonJWT => "NeonJWT",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Either full duplex Framed or write only half; the latter is left in
|
||||
/// PostgresBackend after call to `split`. In principle we could always store a
|
||||
/// pair of splitted handles, but that would force to to pay splitting price
|
||||
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
|
||||
enum MaybeWriteOnly<IO> {
|
||||
Full(Framed<MaybeTlsStream<IO>>),
|
||||
WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
|
||||
Broken, // temporary value palmed off during the split
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
|
||||
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
|
||||
MaybeWriteOnly::WriteOnly(_) => {
|
||||
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.read_message().await,
|
||||
MaybeWriteOnly::WriteOnly(_) => {
|
||||
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.write_message(msg),
|
||||
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.flush().await,
|
||||
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn shutdown(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.shutdown().await,
|
||||
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend<IO> {
|
||||
framed: MaybeWriteOnly<IO>,
|
||||
|
||||
pub state: ProtoState,
|
||||
|
||||
auth_type: AuthType,
|
||||
|
||||
peer_addr: SocketAddr,
|
||||
pub tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
}
|
||||
|
||||
pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
|
||||
|
||||
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
|
||||
let mut query_string = query_string.to_vec();
|
||||
if let Some(ch) = query_string.last() {
|
||||
if *ch == 0 {
|
||||
query_string.pop();
|
||||
}
|
||||
}
|
||||
query_string
|
||||
}
|
||||
|
||||
/// Cast a byte slice to a string slice, dropping null terminator if there's one.
|
||||
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
|
||||
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
|
||||
std::str::from_utf8(without_null).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
impl PostgresBackend<tokio::net::TcpStream> {
|
||||
pub fn new(
|
||||
socket: tokio::net::TcpStream,
|
||||
auth_type: AuthType,
|
||||
tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
) -> io::Result<Self> {
|
||||
let peer_addr = socket.peer_addr()?;
|
||||
let stream = MaybeTlsStream::Unencrypted(socket);
|
||||
|
||||
Ok(Self {
|
||||
framed: MaybeWriteOnly::Full(Framed::new(stream)),
|
||||
state: ProtoState::Initialization,
|
||||
auth_type,
|
||||
tls_config,
|
||||
peer_addr,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
pub fn new_from_io(
|
||||
socket: IO,
|
||||
peer_addr: SocketAddr,
|
||||
auth_type: AuthType,
|
||||
tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
) -> io::Result<Self> {
|
||||
let stream = MaybeTlsStream::Unencrypted(socket);
|
||||
|
||||
Ok(Self {
|
||||
framed: MaybeWriteOnly::Full(Framed::new(stream)),
|
||||
state: ProtoState::Initialization,
|
||||
auth_type,
|
||||
tls_config,
|
||||
peer_addr,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_peer_addr(&self) -> &SocketAddr {
|
||||
&self.peer_addr
|
||||
}
|
||||
|
||||
/// Read full message or return None if connection is cleanly closed with no
|
||||
/// unprocessed data.
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
if let ProtoState::Closed = self.state {
|
||||
Ok(None)
|
||||
} else {
|
||||
let m = self.framed.read_message().await?;
|
||||
trace!("read msg {:?}", m);
|
||||
Ok(m)
|
||||
}
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer, doesn't flush it. Technically
|
||||
/// error type can be only ProtocolError here (if, unlikely, serialization
|
||||
/// fails), but callers typically wrap it anyway.
|
||||
pub fn write_message_noflush(
|
||||
&mut self,
|
||||
message: &BeMessage<'_>,
|
||||
) -> Result<&mut Self, ConnectionError> {
|
||||
self.framed.write_message_noflush(message)?;
|
||||
trace!("wrote msg {:?}", message);
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Flush output buffer into the socket.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
self.framed.flush().await
|
||||
}
|
||||
|
||||
/// Polling version of `flush()`, saves the caller need to pin.
|
||||
pub fn poll_flush(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let flush_fut = self.flush();
|
||||
pin_mut!(flush_fut);
|
||||
flush_fut.poll(cx)
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer and flush it to the stream.
|
||||
pub async fn write_message(
|
||||
&mut self,
|
||||
message: &BeMessage<'_>,
|
||||
) -> Result<&mut Self, ConnectionError> {
|
||||
self.write_message_noflush(message)?;
|
||||
self.flush().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Returns an AsyncWrite implementation that wraps all the data written
|
||||
/// to it in CopyData messages, and writes them to the connection
|
||||
///
|
||||
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
|
||||
pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
|
||||
CopyDataWriter { pgb: self }
|
||||
}
|
||||
|
||||
/// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
pub async fn run<F, S>(
|
||||
mut self,
|
||||
handler: &mut impl Handler<IO>,
|
||||
shutdown_watcher: F,
|
||||
) -> Result<(), QueryError>
|
||||
where
|
||||
F: Fn() -> S,
|
||||
S: Future,
|
||||
{
|
||||
let ret = self.run_message_loop(handler, shutdown_watcher).await;
|
||||
// socket might be already closed, e.g. if previously received error,
|
||||
// so ignore result.
|
||||
self.framed.shutdown().await.ok();
|
||||
ret
|
||||
}
|
||||
|
||||
async fn run_message_loop<F, S>(
|
||||
&mut self,
|
||||
handler: &mut impl Handler<IO>,
|
||||
shutdown_watcher: F,
|
||||
) -> Result<(), QueryError>
|
||||
where
|
||||
F: Fn() -> S,
|
||||
S: Future,
|
||||
{
|
||||
trace!("postgres backend to {:?} started", self.peer_addr);
|
||||
|
||||
tokio::select!(
|
||||
biased;
|
||||
|
||||
_ = shutdown_watcher() => {
|
||||
// We were requested to shut down.
|
||||
tracing::info!("shutdown request received during handshake");
|
||||
return Ok(())
|
||||
},
|
||||
|
||||
result = self.handshake(handler) => {
|
||||
// Handshake complete.
|
||||
result?;
|
||||
if self.state == ProtoState::Closed {
|
||||
return Ok(()); // EOF during handshake
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// Authentication completed
|
||||
let mut query_string = Bytes::new();
|
||||
while let Some(msg) = tokio::select!(
|
||||
biased;
|
||||
_ = shutdown_watcher() => {
|
||||
// We were requested to shut down.
|
||||
tracing::info!("shutdown request received in run_message_loop");
|
||||
Ok(None)
|
||||
},
|
||||
msg = self.read_message() => { msg },
|
||||
)? {
|
||||
trace!("got message {:?}", msg);
|
||||
|
||||
let result = self.process_message(handler, msg, &mut query_string).await;
|
||||
self.flush().await?;
|
||||
match result? {
|
||||
ProcessMsgResult::Continue => {
|
||||
self.flush().await?;
|
||||
continue;
|
||||
}
|
||||
ProcessMsgResult::Break => break,
|
||||
}
|
||||
}
|
||||
|
||||
trace!("postgres backend to {:?} exited", self.peer_addr);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
|
||||
async fn tls_upgrade(
|
||||
src: MaybeTlsStream<IO>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<MaybeTlsStream<IO>> {
|
||||
match src {
|
||||
MaybeTlsStream::Unencrypted(s) => {
|
||||
let acceptor = TlsAcceptor::from(tls_config);
|
||||
let tls_stream = acceptor.accept(s).await?;
|
||||
Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
|
||||
}
|
||||
MaybeTlsStream::Tls(_) => {
|
||||
anyhow::bail!("TLS already started");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_tls(&mut self) -> anyhow::Result<()> {
|
||||
// temporary replace stream with fake to cook TLS one, Indiana Jones style
|
||||
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
|
||||
MaybeWriteOnly::Full(framed) => {
|
||||
let tls_config = self
|
||||
.tls_config
|
||||
.as_ref()
|
||||
.context("start_tls called without conf")?
|
||||
.clone();
|
||||
let tls_framed = framed
|
||||
.map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
|
||||
.await?;
|
||||
// push back ready TLS stream
|
||||
self.framed = MaybeWriteOnly::Full(tls_framed);
|
||||
Ok(())
|
||||
}
|
||||
MaybeWriteOnly::WriteOnly(_) => {
|
||||
anyhow::bail!("TLS upgrade attempt in split state")
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Split off owned read part from which messages can be read in different
|
||||
/// task/thread.
|
||||
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
|
||||
// temporary replace stream with fake to cook split one, Indiana Jones style
|
||||
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
|
||||
MaybeWriteOnly::Full(framed) => {
|
||||
let (reader, writer) = framed.split();
|
||||
self.framed = MaybeWriteOnly::WriteOnly(writer);
|
||||
Ok(PostgresBackendReader(reader))
|
||||
}
|
||||
MaybeWriteOnly::WriteOnly(_) => {
|
||||
anyhow::bail!("PostgresBackend is already split")
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Join read part back.
|
||||
pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
|
||||
// temporary replace stream with fake to cook joined one, Indiana Jones style
|
||||
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
|
||||
MaybeWriteOnly::Full(_) => {
|
||||
anyhow::bail!("PostgresBackend is not split")
|
||||
}
|
||||
MaybeWriteOnly::WriteOnly(writer) => {
|
||||
let joined = Framed::unsplit(reader.0, writer);
|
||||
self.framed = MaybeWriteOnly::Full(joined);
|
||||
Ok(())
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform handshake with the client, transitioning to Established.
|
||||
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
|
||||
async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
|
||||
while self.state < ProtoState::Authentication {
|
||||
match self.framed.read_startup_message().await? {
|
||||
Some(msg) => {
|
||||
self.process_startup_message(handler, msg).await?;
|
||||
}
|
||||
None => {
|
||||
trace!(
|
||||
"postgres backend to {:?} received EOF during handshake",
|
||||
self.peer_addr
|
||||
);
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform auth, if needed.
|
||||
if self.state == ProtoState::Authentication {
|
||||
match self.framed.read_message().await? {
|
||||
Some(FeMessage::PasswordMessage(m)) => {
|
||||
assert!(self.auth_type == AuthType::NeonJWT);
|
||||
|
||||
let (_, jwt_response) = m.split_last().context("protocol violation")?;
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
Some(m) => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"Unexpected message {:?} while waiting for handshake",
|
||||
m
|
||||
)));
|
||||
}
|
||||
None => {
|
||||
trace!(
|
||||
"postgres backend to {:?} received EOF during auth",
|
||||
self.peer_addr
|
||||
);
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Process startup packet:
|
||||
/// - transition to Established if auth type is trust
|
||||
/// - transition to Authentication if auth type is NeonJWT.
|
||||
/// - or perform TLS handshake -- then need to call this again to receive
|
||||
/// actual startup packet.
|
||||
async fn process_startup_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler<IO>,
|
||||
msg: FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
assert!(self.state < ProtoState::Authentication);
|
||||
let have_tls = self.tls_config.is_some();
|
||||
match msg {
|
||||
FeStartupPacket::SslRequest => {
|
||||
debug!("SSL requested");
|
||||
|
||||
self.write_message(&BeMessage::EncryptionResponse(have_tls))
|
||||
.await?;
|
||||
|
||||
if have_tls {
|
||||
self.start_tls().await?;
|
||||
self.state = ProtoState::Encrypted;
|
||||
}
|
||||
}
|
||||
FeStartupPacket::GssEncRequest => {
|
||||
debug!("GSS requested");
|
||||
self.write_message(&BeMessage::EncryptionResponse(false))
|
||||
.await?;
|
||||
}
|
||||
FeStartupPacket::StartupMessage { .. } => {
|
||||
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
|
||||
self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
|
||||
.await?;
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"client did not connect with TLS"
|
||||
)));
|
||||
}
|
||||
|
||||
// NB: startup() may change self.auth_type -- we are using that in proxy code
|
||||
// to bypass auth for new users.
|
||||
handler.startup(self, &msg)?;
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
|
||||
// The async python driver requires a valid server_version
|
||||
.write_message_noflush(&BeMessage::server_version("14.1"))?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
AuthType::NeonJWT => {
|
||||
self.write_message(&BeMessage::AuthenticationCleartextPassword)
|
||||
.await?;
|
||||
self.state = ProtoState::Authentication;
|
||||
}
|
||||
}
|
||||
}
|
||||
FeStartupPacket::CancelRequest { .. } => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"Unexpected CancelRequest message during handshake"
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn process_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler<IO>,
|
||||
msg: FeMessage,
|
||||
unnamed_query_string: &mut Bytes,
|
||||
) -> Result<ProcessMsgResult, QueryError> {
|
||||
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
|
||||
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
|
||||
assert!(self.state == ProtoState::Established);
|
||||
|
||||
match msg {
|
||||
FeMessage::Query(body) => {
|
||||
// remove null terminator
|
||||
let query_string = cstr_to_str(&body)?;
|
||||
|
||||
trace!("got query {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string).await {
|
||||
log_query_error(query_string, &e);
|
||||
let short_error = short_error(&e);
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&short_error,
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
|
||||
FeMessage::Parse(m) => {
|
||||
*unnamed_query_string = m.query_string;
|
||||
self.write_message_noflush(&BeMessage::ParseComplete)?;
|
||||
}
|
||||
|
||||
FeMessage::Describe(_) => {
|
||||
self.write_message_noflush(&BeMessage::ParameterDescription)?
|
||||
.write_message_noflush(&BeMessage::NoData)?;
|
||||
}
|
||||
|
||||
FeMessage::Bind(_) => {
|
||||
self.write_message_noflush(&BeMessage::BindComplete)?;
|
||||
}
|
||||
|
||||
FeMessage::Close(_) => {
|
||||
self.write_message_noflush(&BeMessage::CloseComplete)?;
|
||||
}
|
||||
|
||||
FeMessage::Execute(_) => {
|
||||
let query_string = cstr_to_str(unnamed_query_string)?;
|
||||
trace!("got execute {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string).await {
|
||||
log_query_error(query_string, &e);
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
// NOTE there is no ReadyForQuery message. This handler is used
|
||||
// for basebackup and it uses CopyOut which doesn't require
|
||||
// ReadyForQuery message and backend just switches back to
|
||||
// processing mode after sending CopyDone or ErrorResponse.
|
||||
}
|
||||
|
||||
FeMessage::Sync => {
|
||||
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
|
||||
FeMessage::Terminate => {
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
|
||||
// We prefer explicit pattern matching to wildcards, because
|
||||
// this helps us spot the places where new variants are missing
|
||||
FeMessage::CopyData(_)
|
||||
| FeMessage::CopyDone
|
||||
| FeMessage::CopyFail
|
||||
| FeMessage::PasswordMessage(_) => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unexpected message type: {msg:?}",
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ProcessMsgResult::Continue)
|
||||
}
|
||||
|
||||
/// Log as info/error result of handling COPY stream and send back
|
||||
/// ErrorResponse if that makes sense. Shutdown the stream if we got
|
||||
/// Terminate. TODO: transition into waiting for Sync msg if we initiate the
|
||||
/// close.
|
||||
pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
|
||||
use CopyStreamHandlerEnd::*;
|
||||
|
||||
let expected_end = match &end {
|
||||
ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
|
||||
CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
|
||||
if is_expected_io_error(io_error) =>
|
||||
{
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
if expected_end {
|
||||
info!("terminated: {:#}", end);
|
||||
} else {
|
||||
error!("terminated: {:?}", end);
|
||||
}
|
||||
|
||||
// Note: no current usages ever send this
|
||||
if let CopyDone = &end {
|
||||
if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
|
||||
error!("failed to send CopyDone: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
if let Terminate = &end {
|
||||
self.state = ProtoState::Closed;
|
||||
}
|
||||
|
||||
let err_to_send_and_errcode = match &end {
|
||||
ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
|
||||
Other(_) => Some((end.to_string(), SQLSTATE_INTERNAL_ERROR)),
|
||||
// Note: CopyFail in duplex copy is somewhat unexpected (at least to
|
||||
// PG walsender; evidently and per my docs reading client should
|
||||
// finish it with CopyDone). It is not a problem to recover from it
|
||||
// finishing the stream in both directions like we do, but note that
|
||||
// sync rust-postgres client (which we don't use anymore) hangs if
|
||||
// socket is not closed here.
|
||||
// https://github.com/sfackler/rust-postgres/issues/755
|
||||
// https://github.com/neondatabase/neon/issues/935
|
||||
//
|
||||
// Currently, the version of tokio_postgres replication patch we use
|
||||
// sends this when it closes the stream (e.g. pageserver decided to
|
||||
// switch conn to another safekeeper and client gets dropped).
|
||||
// Moreover, seems like 'connection' task errors with 'unexpected
|
||||
// message from server' when it receives ErrorResponse (anything but
|
||||
// CopyData/CopyDone) back.
|
||||
CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
|
||||
_ => None,
|
||||
};
|
||||
if let Some((err, errcode)) = err_to_send_and_errcode {
|
||||
if let Err(ee) = self
|
||||
.write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
|
||||
.await
|
||||
{
|
||||
error!("failed to send ErrorResponse: {}", ee);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackendReader<IO>(FramedReader<MaybeTlsStream<IO>>);
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
|
||||
/// Read full message or return None if connection is cleanly closed with no
|
||||
/// unprocessed data.
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
let m = self.0.read_message().await?;
|
||||
trace!("read msg {:?}", m);
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
/// Get CopyData contents of the next message in COPY stream or error
|
||||
/// closing it. The error type is wider than actual errors which can happen
|
||||
/// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
|
||||
/// current callers.
|
||||
pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
|
||||
match self.read_message().await? {
|
||||
Some(msg) => match msg {
|
||||
FeMessage::CopyData(m) => Ok(m),
|
||||
FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
|
||||
FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
|
||||
FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
|
||||
_ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
|
||||
ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
|
||||
))),
|
||||
},
|
||||
None => Err(CopyStreamHandlerEnd::EOF),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
|
||||
/// messages.
|
||||
///
|
||||
|
||||
pub struct CopyDataWriter<'a, IO> {
|
||||
pgb: &'a mut PostgresBackend<IO>,
|
||||
}
|
||||
|
||||
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// It's not strictly required to flush between each message, but makes it easier
|
||||
// to view in wireshark, and usually the messages that the callers write are
|
||||
// decently-sized anyway.
|
||||
if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
|
||||
// CopyData
|
||||
// XXX: if the input is large, we should split it into multiple messages.
|
||||
// Not sure what the threshold should be, but the ultimate hard limit is that
|
||||
// the length cannot exceed u32.
|
||||
this.pgb
|
||||
.write_message_noflush(&BeMessage::CopyData(buf))
|
||||
// write_message only writes to the buffer, so it can fail iff the
|
||||
// message is invaid, but CopyData can't be invalid.
|
||||
.map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
|
||||
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
this.pgb.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
this.pgb.poll_flush(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn short_error(e: &QueryError) -> String {
|
||||
match e {
|
||||
QueryError::Disconnected(connection_error) => connection_error.to_string(),
|
||||
QueryError::Other(e) => format!("{e:#}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn log_query_error(query: &str, e: &QueryError) {
|
||||
match e {
|
||||
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
|
||||
if is_expected_io_error(io_error) {
|
||||
info!("query handler for '{query}' failed with expected io error: {io_error}");
|
||||
} else {
|
||||
error!("query handler for '{query}' failed with io error: {io_error}");
|
||||
}
|
||||
}
|
||||
QueryError::Disconnected(other_connection_error) => {
|
||||
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
|
||||
}
|
||||
QueryError::Other(e) => {
|
||||
error!("query handler for '{query}' failed: {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Something finishing handling of COPY stream, see handle_copy_stream_end.
|
||||
/// This is not always a real error, but it allows to use ? and thiserror impls.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum CopyStreamHandlerEnd {
|
||||
/// Handler initiates the end of streaming.
|
||||
#[error("{0}")]
|
||||
ServerInitiated(String),
|
||||
#[error("received CopyDone")]
|
||||
CopyDone,
|
||||
#[error("received CopyFail")]
|
||||
CopyFail,
|
||||
#[error("received Terminate")]
|
||||
Terminate,
|
||||
#[error("EOF on COPY stream")]
|
||||
EOF,
|
||||
/// The connection was lost
|
||||
#[error(transparent)]
|
||||
Disconnected(#[from] ConnectionError),
|
||||
/// Some other error
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
@@ -1,140 +0,0 @@
|
||||
/// Test postgres_backend_async with tokio_postgres
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
|
||||
use pq_proto::{BeMessage, RowDescriptor};
|
||||
use std::io::Cursor;
|
||||
use std::{future, sync::Arc};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
// generate client, server test streams
|
||||
async fn make_tcp_pair() -> (TcpStream, TcpStream) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let client_stream = TcpStream::connect(addr).await.unwrap();
|
||||
let (server_stream, _) = listener.accept().await.unwrap();
|
||||
(client_stream, server_stream)
|
||||
}
|
||||
|
||||
struct TestHandler {}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin + Send> Handler<IO> for TestHandler {
|
||||
// return single col 'hey' for any query
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
_query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
|
||||
b"hey",
|
||||
)]))?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// test that basic select works
|
||||
#[tokio::test]
|
||||
async fn simple_select() {
|
||||
let (client_sock, server_sock) = make_tcp_pair().await;
|
||||
|
||||
// create and run pgbackend
|
||||
let pgbackend =
|
||||
PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut handler = TestHandler {};
|
||||
pgbackend.run(&mut handler, future::pending::<()>).await
|
||||
});
|
||||
|
||||
let conf = Config::new();
|
||||
let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
|
||||
// The connection object performs the actual communication with the database,
|
||||
// so spawn it off to run on its own.
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
|
||||
if let SimpleQueryMessage::Row(row) = first_val {
|
||||
let first_col = row.get(0).expect("first column");
|
||||
assert_eq!(first_col, "hey");
|
||||
} else {
|
||||
panic!("expected SimpleQueryMessage::Row");
|
||||
}
|
||||
}
|
||||
|
||||
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
|
||||
let mut cursor = Cursor::new(include_bytes!("key.pem"));
|
||||
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
|
||||
});
|
||||
|
||||
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
|
||||
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
|
||||
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
|
||||
});
|
||||
|
||||
// test that basic select with ssl works
|
||||
#[tokio::test]
|
||||
async fn simple_select_ssl() {
|
||||
let (client_sock, server_sock) = make_tcp_pair().await;
|
||||
|
||||
let server_cfg = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![CERT.clone()], KEY.clone())
|
||||
.unwrap();
|
||||
let tls_config = Some(Arc::new(server_cfg));
|
||||
let pgbackend =
|
||||
PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut handler = TestHandler {};
|
||||
pgbackend.run(&mut handler, future::pending::<()>).await
|
||||
});
|
||||
|
||||
let client_cfg = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates({
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add(&CERT).unwrap();
|
||||
store
|
||||
})
|
||||
.with_no_client_auth();
|
||||
let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
|
||||
let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
|
||||
&mut make_tls_connect,
|
||||
"localhost",
|
||||
)
|
||||
.expect("make_tls_connect");
|
||||
|
||||
let mut conf = Config::new();
|
||||
conf.ssl_mode(SslMode::Require);
|
||||
let (client, connection) = conf
|
||||
.connect_raw(client_sock, tls_connect)
|
||||
.await
|
||||
.expect("connect");
|
||||
// The connection object performs the actual communication with the database,
|
||||
// so spawn it off to run on its own.
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
|
||||
if let SimpleQueryMessage::Row(row) = first_val {
|
||||
let first_col = row.get(0).expect("first column");
|
||||
assert_eq!(first_col, "hey");
|
||||
} else {
|
||||
panic!("expected SimpleQueryMessage::Row");
|
||||
}
|
||||
}
|
||||
@@ -63,7 +63,10 @@ fn main() -> anyhow::Result<()> {
|
||||
pg_install_dir_versioned = cwd.join("..").join("..").join(pg_install_dir_versioned);
|
||||
}
|
||||
|
||||
let pg_config_bin = pg_install_dir_versioned.join("bin").join("pg_config");
|
||||
let pg_config_bin = pg_install_dir_versioned
|
||||
.join(pg_version)
|
||||
.join("bin")
|
||||
.join("pg_config");
|
||||
let inc_server_path: String = if pg_config_bin.exists() {
|
||||
let output = Command::new(pg_config_bin)
|
||||
.arg("--includedir-server")
|
||||
|
||||
@@ -5,8 +5,8 @@ edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
bytes.workspace = true
|
||||
byteorder.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
@@ -1,244 +0,0 @@
|
||||
//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
|
||||
//! the async stream based on (and buffered with) BytesMut. All functions are
|
||||
//! cancellation safe.
|
||||
//!
|
||||
//! It is similar to what tokio_util::codec::Framed with appropriate codec
|
||||
//! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
|
||||
//! separately without using split from futures::stream::StreamExt (which
|
||||
//! allocates box[1] in polling internally). tokio::io::split is used for splitting
|
||||
//! instead. Plus we customize error messages more than a single type for all io
|
||||
//! calls.
|
||||
//!
|
||||
//! [1] https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
|
||||
use bytes::{Buf, BytesMut};
|
||||
use std::{
|
||||
future::Future,
|
||||
io::{self, ErrorKind},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
||||
|
||||
use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
|
||||
const INITIAL_CAPACITY: usize = 8 * 1024;
|
||||
|
||||
/// Error on postgres connection: either IO (physical transport error) or
|
||||
/// protocol violation.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ConnectionError {
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
#[error(transparent)]
|
||||
Protocol(#[from] ProtocolError),
|
||||
}
|
||||
|
||||
impl ConnectionError {
|
||||
/// Proxy stream.rs uses only io::Error; provide it.
|
||||
pub fn into_io_error(self) -> io::Error {
|
||||
match self {
|
||||
ConnectionError::Io(io) => io,
|
||||
ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps async io `stream`, providing messages to write/flush + read Postgres
|
||||
/// messages.
|
||||
pub struct Framed<S> {
|
||||
stream: S,
|
||||
read_buf: BytesMut,
|
||||
write_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S> Framed<S> {
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
|
||||
write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a shared reference to the underlying stream.
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.stream
|
||||
}
|
||||
|
||||
/// Deconstruct into the underlying stream and read buffer.
|
||||
pub fn into_inner(self) -> (S, BytesMut) {
|
||||
(self.stream, self.read_buf)
|
||||
}
|
||||
|
||||
/// Return new Framed with stream type transformed by async f, for TLS
|
||||
/// upgrade.
|
||||
pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
|
||||
where
|
||||
F: FnOnce(S) -> Fut,
|
||||
Fut: Future<Output = Result<S2, E>>,
|
||||
{
|
||||
let stream = f(self.stream).await?;
|
||||
Ok(Framed {
|
||||
stream,
|
||||
read_buf: self.read_buf,
|
||||
write_buf: self.write_buf,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> Framed<S> {
|
||||
pub async fn read_startup_message(
|
||||
&mut self,
|
||||
) -> Result<Option<FeStartupPacket>, ConnectionError> {
|
||||
read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
|
||||
}
|
||||
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> Framed<S> {
|
||||
/// Write next message to the output buffer; doesn't flush.
|
||||
pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
|
||||
BeMessage::write(&mut self.write_buf, msg)
|
||||
}
|
||||
|
||||
/// Flush out the buffer. This function is cancellation safe: it can be
|
||||
/// interrupted and flushing will be continued in the next call.
|
||||
pub async fn flush(&mut self) -> Result<(), io::Error> {
|
||||
flush(&mut self.stream, &mut self.write_buf).await
|
||||
}
|
||||
|
||||
/// Flush out the buffer and shutdown the stream.
|
||||
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
|
||||
shutdown(&mut self.stream, &mut self.write_buf).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
|
||||
/// Split into owned read and write parts. Beware of potential issues with
|
||||
/// using halves in different tasks on TLS stream:
|
||||
/// https://github.com/tokio-rs/tls/issues/40
|
||||
pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
|
||||
let (read_half, write_half) = tokio::io::split(self.stream);
|
||||
let reader = FramedReader {
|
||||
stream: read_half,
|
||||
read_buf: self.read_buf,
|
||||
};
|
||||
let writer = FramedWriter {
|
||||
stream: write_half,
|
||||
write_buf: self.write_buf,
|
||||
};
|
||||
(reader, writer)
|
||||
}
|
||||
|
||||
/// Join read and write parts back.
|
||||
pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
|
||||
Self {
|
||||
stream: reader.stream.unsplit(writer.stream),
|
||||
read_buf: reader.read_buf,
|
||||
write_buf: writer.write_buf,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read-only version of `Framed`.
|
||||
pub struct FramedReader<S> {
|
||||
stream: ReadHalf<S>,
|
||||
read_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> FramedReader<S> {
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Write-only version of `Framed`.
|
||||
pub struct FramedWriter<S> {
|
||||
stream: WriteHalf<S>,
|
||||
write_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> FramedWriter<S> {
|
||||
/// Write next message to the output buffer; doesn't flush.
|
||||
pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
|
||||
BeMessage::write(&mut self.write_buf, msg)
|
||||
}
|
||||
|
||||
/// Flush out the buffer. This function is cancellation safe: it can be
|
||||
/// interrupted and flushing will be continued in the next call.
|
||||
pub async fn flush(&mut self) -> Result<(), io::Error> {
|
||||
flush(&mut self.stream, &mut self.write_buf).await
|
||||
}
|
||||
|
||||
/// Flush out the buffer and shutdown the stream.
|
||||
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
|
||||
shutdown(&mut self.stream, &mut self.write_buf).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Read next message from the stream. Returns Ok(None), if EOF happened and we
|
||||
/// don't have remaining data in the buffer. This function is cancellation safe:
|
||||
/// you can drop future which is not yet complete and finalize reading message
|
||||
/// with the next call.
|
||||
///
|
||||
/// Parametrized to allow reading startup or usual message, having different
|
||||
/// format.
|
||||
async fn read_message<S: AsyncRead + Unpin, M, P>(
|
||||
stream: &mut S,
|
||||
read_buf: &mut BytesMut,
|
||||
parse: P,
|
||||
) -> Result<Option<M>, ConnectionError>
|
||||
where
|
||||
P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
|
||||
{
|
||||
loop {
|
||||
if let Some(msg) = parse(read_buf)? {
|
||||
return Ok(Some(msg));
|
||||
}
|
||||
// If we can't build a frame yet, try to read more data and try again.
|
||||
// Make sure we've got room for at least one byte to read to ensure
|
||||
// that we don't get a spurious 0 that looks like EOF.
|
||||
read_buf.reserve(1);
|
||||
if stream.read_buf(read_buf).await? == 0 {
|
||||
if read_buf.has_remaining() {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::UnexpectedEof,
|
||||
"EOF with unprocessed data in the buffer",
|
||||
)
|
||||
.into());
|
||||
} else {
|
||||
return Ok(None); // clean EOF
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn flush<S: AsyncWrite + Unpin>(
|
||||
stream: &mut S,
|
||||
write_buf: &mut BytesMut,
|
||||
) -> Result<(), io::Error> {
|
||||
while write_buf.has_remaining() {
|
||||
let bytes_written = stream.write(write_buf.chunk()).await?;
|
||||
if bytes_written == 0 {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::WriteZero,
|
||||
"failed to write message",
|
||||
));
|
||||
}
|
||||
// The advanced part will be garbage collected, likely during shifting
|
||||
// data left on next attempt to write to buffer when free space is not
|
||||
// enough.
|
||||
write_buf.advance(bytes_written);
|
||||
}
|
||||
write_buf.clear();
|
||||
stream.flush().await
|
||||
}
|
||||
|
||||
async fn shutdown<S: AsyncWrite + Unpin>(
|
||||
stream: &mut S,
|
||||
write_buf: &mut BytesMut,
|
||||
) -> Result<(), io::Error> {
|
||||
flush(stream, write_buf).await?;
|
||||
stream.shutdown().await
|
||||
}
|
||||
@@ -2,18 +2,24 @@
|
||||
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
|
||||
//! on message formats.
|
||||
|
||||
pub mod framed;
|
||||
// Tools for calling certain async methods in sync contexts.
|
||||
pub mod sync;
|
||||
|
||||
use byteorder::{BigEndian, ReadBytesExt};
|
||||
use anyhow::{ensure, Context, Result};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use postgres_protocol::PG_EPOCH;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
fmt, io, str,
|
||||
fmt,
|
||||
future::Future,
|
||||
io::{self, Cursor},
|
||||
str,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use sync::{AsyncishRead, SyncFuture};
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tracing::{trace, warn};
|
||||
|
||||
pub type Oid = u32;
|
||||
@@ -25,6 +31,7 @@ pub const TEXT_OID: Oid = 25;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum FeMessage {
|
||||
StartupPacket(FeStartupPacket),
|
||||
// Simple query.
|
||||
Query(Bytes),
|
||||
// Extended query protocol.
|
||||
@@ -184,205 +191,260 @@ pub struct FeExecuteMessage {
|
||||
#[derive(Debug)]
|
||||
pub struct FeCloseMessage;
|
||||
|
||||
/// An error occured while parsing or serializing raw stream into Postgres
|
||||
/// messages.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ProtocolError {
|
||||
/// Invalid packet was received from the client (e.g. unexpected message
|
||||
/// type or broken len).
|
||||
#[error("Protocol error: {0}")]
|
||||
Protocol(String),
|
||||
/// Failed to parse or, (unlikely), serialize a protocol message.
|
||||
#[error("Message parse error: {0}")]
|
||||
BadMessage(String),
|
||||
/// Retry a read on EINTR
|
||||
///
|
||||
/// This runs the enclosed expression, and if it returns
|
||||
/// Err(io::ErrorKind::Interrupted), retries it.
|
||||
macro_rules! retry_read {
|
||||
( $x:expr ) => {
|
||||
loop {
|
||||
match $x {
|
||||
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
|
||||
res => break res,
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl ProtocolError {
|
||||
/// Proxy stream.rs uses only io::Error; provide it.
|
||||
/// An error occured during connection being open.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ConnectionError {
|
||||
/// IO error during writing to or reading from the connection socket.
|
||||
#[error("Socket IO error: {0}")]
|
||||
Socket(std::io::Error),
|
||||
/// Invalid packet was received from client
|
||||
#[error("Protocol error: {0}")]
|
||||
Protocol(String),
|
||||
/// Failed to parse a protocol mesage
|
||||
#[error("Message parse error: {0}")]
|
||||
MessageParse(anyhow::Error),
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for ConnectionError {
|
||||
fn from(e: anyhow::Error) -> Self {
|
||||
Self::MessageParse(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionError {
|
||||
pub fn into_io_error(self) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, self.to_string())
|
||||
match self {
|
||||
ConnectionError::Socket(io) => io,
|
||||
other => io::Error::new(io::ErrorKind::Other, other.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FeMessage {
|
||||
/// Read and parse one message from the `buf` input buffer. If there is at
|
||||
/// least one valid message, returns it, advancing `buf`; redundant copies
|
||||
/// are avoided, as thanks to `bytes` crate ptrs in parsed message point
|
||||
/// directly into the `buf` (processed data is garbage collected after
|
||||
/// parsed message is dropped).
|
||||
/// Read one message from the stream.
|
||||
/// This function returns `Ok(None)` in case of EOF.
|
||||
/// One way to handle this properly:
|
||||
///
|
||||
/// Returns None if `buf` doesn't contain enough data for a single message.
|
||||
/// For efficiency, tries to reserve large enough space in `buf` for the
|
||||
/// next message in this case to save the repeated calls.
|
||||
/// ```
|
||||
/// # use std::io;
|
||||
/// # use pq_proto::FeMessage;
|
||||
/// #
|
||||
/// # fn process_message(msg: FeMessage) -> anyhow::Result<()> {
|
||||
/// # Ok(())
|
||||
/// # };
|
||||
/// #
|
||||
/// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> {
|
||||
/// while let Some(msg) = FeMessage::read(stream)? {
|
||||
/// process_message(msg)?;
|
||||
/// }
|
||||
///
|
||||
/// Returns Error if message is malformed, the only possible ErrorKind is
|
||||
/// InvalidInput.
|
||||
//
|
||||
// Inspired by rust-postgres Message::parse.
|
||||
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
|
||||
// Every message contains message type byte and 4 bytes len; can't do
|
||||
// much without them.
|
||||
if buf.len() < 5 {
|
||||
let to_read = 5 - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
/// Ok(())
|
||||
/// }
|
||||
/// ```
|
||||
#[inline(never)]
|
||||
pub fn read(
|
||||
stream: &mut (impl io::Read + Unpin),
|
||||
) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
Self::read_fut(&mut AsyncishRead(stream)).wait()
|
||||
}
|
||||
|
||||
// We shouldn't advance `buf` as probably full message is not there yet,
|
||||
// so can't directly use Bytes::get_u32 etc.
|
||||
let tag = buf[0];
|
||||
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
|
||||
if len < 4 {
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"invalid message length {}",
|
||||
len
|
||||
)));
|
||||
}
|
||||
/// Read one message from the stream.
|
||||
/// See documentation for `Self::read`.
|
||||
pub fn read_fut<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
|
||||
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
|
||||
// AsyncReadExt methods of the stream.
|
||||
SyncFuture::new(async move {
|
||||
// Each libpq message begins with a message type byte, followed by message length
|
||||
// If the client closes the connection, return None. But if the client closes the
|
||||
// connection in the middle of a message, we will return an error.
|
||||
let tag = match retry_read!(stream.read_u8().await) {
|
||||
Ok(b) => b,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(ConnectionError::Socket(e)),
|
||||
};
|
||||
|
||||
// length field includes itself, but not message type.
|
||||
let total_len = len as usize + 1;
|
||||
if buf.len() < total_len {
|
||||
// Don't have full message yet.
|
||||
let to_read = total_len - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
// The message length includes itself, so it better be at least 4.
|
||||
let len = retry_read!(stream.read_u32().await)
|
||||
.map_err(ConnectionError::Socket)?
|
||||
.checked_sub(4)
|
||||
.ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?;
|
||||
|
||||
// got the message, advance buffer
|
||||
let mut msg = buf.split_to(total_len).freeze();
|
||||
msg.advance(5); // consume message type and len
|
||||
let body = {
|
||||
let mut buffer = vec![0u8; len as usize];
|
||||
stream
|
||||
.read_exact(&mut buffer)
|
||||
.await
|
||||
.map_err(ConnectionError::Socket)?;
|
||||
Bytes::from(buffer)
|
||||
};
|
||||
|
||||
match tag {
|
||||
b'Q' => Ok(Some(FeMessage::Query(msg))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
|
||||
b'S' => Ok(Some(FeMessage::Sync)),
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(msg))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
b'f' => Ok(Some(FeMessage::CopyFail)),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
|
||||
tag => Err(ProtocolError::Protocol(format!(
|
||||
"unknown message tag: {tag},'{msg:?}'"
|
||||
))),
|
||||
}
|
||||
match tag {
|
||||
b'Q' => Ok(Some(FeMessage::Query(body))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
|
||||
b'S' => Ok(Some(FeMessage::Sync)),
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(body))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
b'f' => Ok(Some(FeMessage::CopyFail)),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
|
||||
tag => {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
"unknown message tag: {tag},'{body:?}'"
|
||||
)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FeStartupPacket {
|
||||
/// Read and parse startup message from the `buf` input buffer. It is
|
||||
/// different from [`FeMessage::parse`] because startup messages don't have
|
||||
/// message type byte; otherwise, its comments apply.
|
||||
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
|
||||
/// Read startup message from the stream.
|
||||
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
||||
// since such a change will cause user-supplied &mut references to be consumed
|
||||
pub fn read(
|
||||
stream: &mut (impl io::Read + Unpin),
|
||||
) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
Self::read_fut(&mut AsyncishRead(stream)).wait()
|
||||
}
|
||||
|
||||
/// Read startup message from the stream.
|
||||
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
||||
// since such a change will cause user-supplied &mut references to be consumed
|
||||
pub fn read_fut<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
|
||||
const CANCEL_REQUEST_CODE: u32 = 5678;
|
||||
const NEGOTIATE_SSL_CODE: u32 = 5679;
|
||||
const NEGOTIATE_GSS_CODE: u32 = 5680;
|
||||
|
||||
// need at least 4 bytes with packet len
|
||||
if buf.len() < 4 {
|
||||
let to_read = 4 - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
SyncFuture::new(async move {
|
||||
// Read length. If the connection is closed before reading anything (or before
|
||||
// reading 4 bytes, to be precise), return None to indicate that the connection
|
||||
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
|
||||
// in the log if the client opens connection but closes it immediately.
|
||||
let len = match retry_read!(stream.read_u32().await) {
|
||||
Ok(len) => len as usize,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(ConnectionError::Socket(e)),
|
||||
};
|
||||
|
||||
// We shouldn't advance `buf` as probably full message is not there yet,
|
||||
// so can't directly use Bytes::get_u32 etc.
|
||||
let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
|
||||
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"invalid startup packet message length {}",
|
||||
len
|
||||
)));
|
||||
}
|
||||
|
||||
if buf.len() < len {
|
||||
// Don't have full message yet.
|
||||
let to_read = len - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// got the message, advance buffer
|
||||
let mut msg = buf.split_to(len).freeze();
|
||||
msg.advance(4); // consume len
|
||||
|
||||
let request_code = msg.get_u32();
|
||||
let req_hi = request_code >> 16;
|
||||
let req_lo = request_code & ((1 << 16) - 1);
|
||||
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
|
||||
let message = match (req_hi, req_lo) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
if msg.remaining() != 8 {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"CancelRequest message is malformed, backend PID / secret key missing"
|
||||
.to_owned(),
|
||||
));
|
||||
}
|
||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||
backend_pid: msg.get_i32(),
|
||||
cancel_key: msg.get_i32(),
|
||||
})
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
||||
// Requested upgrade to SSL (aka TLS)
|
||||
FeStartupPacket::SslRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
|
||||
// Requested upgrade to GSSAPI
|
||||
FeStartupPacket::GssEncRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"Unrecognized request code {unrecognized_code}"
|
||||
#[allow(clippy::manual_range_contains)]
|
||||
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
"invalid message length {len}"
|
||||
)));
|
||||
}
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
(major_version, minor_version) => {
|
||||
// StartupMessage
|
||||
|
||||
// Parse pairs of null-terminated strings (key, value).
|
||||
// See `postgres: ProcessStartupPacket, build_startup_packet`.
|
||||
let mut tokens = str::from_utf8(&msg)
|
||||
.map_err(|_e| {
|
||||
ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
|
||||
})?
|
||||
.strip_suffix('\0') // drop packet's own null
|
||||
.ok_or_else(|| {
|
||||
ProtocolError::Protocol(
|
||||
"StartupMessage params: missing null terminator".to_string(),
|
||||
)
|
||||
})?
|
||||
.split_terminator('\0');
|
||||
let request_code =
|
||||
retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?;
|
||||
|
||||
let mut params = HashMap::new();
|
||||
while let Some(name) = tokens.next() {
|
||||
let value = tokens.next().ok_or_else(|| {
|
||||
ProtocolError::Protocol(
|
||||
"StartupMessage params: key without value".to_string(),
|
||||
)
|
||||
})?;
|
||||
// the rest of startup packet are params
|
||||
let params_len = len - 8;
|
||||
let mut params_bytes = vec![0u8; params_len];
|
||||
stream
|
||||
.read_exact(params_bytes.as_mut())
|
||||
.await
|
||||
.map_err(ConnectionError::Socket)?;
|
||||
|
||||
params.insert(name.to_owned(), value.to_owned());
|
||||
// Parse params depending on request code
|
||||
let req_hi = request_code >> 16;
|
||||
let req_lo = request_code & ((1 << 16) - 1);
|
||||
let message = match (req_hi, req_lo) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
if params_len != 8 {
|
||||
return Err(ConnectionError::Protocol(
|
||||
"expected 8 bytes for CancelRequest params".to_string(),
|
||||
));
|
||||
}
|
||||
let mut cursor = Cursor::new(params_bytes);
|
||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||
backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
})
|
||||
}
|
||||
|
||||
FeStartupPacket::StartupMessage {
|
||||
major_version,
|
||||
minor_version,
|
||||
params: StartupMessageParams { params },
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
||||
// Requested upgrade to SSL (aka TLS)
|
||||
FeStartupPacket::SslRequest
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Some(message))
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
|
||||
// Requested upgrade to GSSAPI
|
||||
FeStartupPacket::GssEncRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
"Unrecognized request code {unrecognized_code}"
|
||||
)));
|
||||
}
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
(major_version, minor_version) => {
|
||||
// Parse pairs of null-terminated strings (key, value).
|
||||
// See `postgres: ProcessStartupPacket, build_startup_packet`.
|
||||
let mut tokens = str::from_utf8(¶ms_bytes)
|
||||
.context("StartupMessage params: invalid utf-8")?
|
||||
.strip_suffix('\0') // drop packet's own null
|
||||
.ok_or_else(|| {
|
||||
ConnectionError::Protocol(
|
||||
"StartupMessage params: missing null terminator".to_string(),
|
||||
)
|
||||
})?
|
||||
.split_terminator('\0');
|
||||
|
||||
let mut params = HashMap::new();
|
||||
while let Some(name) = tokens.next() {
|
||||
let value = tokens.next().ok_or_else(|| {
|
||||
ConnectionError::Protocol(
|
||||
"StartupMessage params: key without value".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
params.insert(name.to_owned(), value.to_owned());
|
||||
}
|
||||
|
||||
FeStartupPacket::StartupMessage {
|
||||
major_version,
|
||||
minor_version,
|
||||
params: StartupMessageParams { params },
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Some(FeMessage::StartupPacket(message)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FeParseMessage {
|
||||
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
// FIXME: the rust-postgres driver uses a named prepared statement
|
||||
// for copy_out(). We're not prepared to handle that correctly. For
|
||||
// now, just ignore the statement name, assuming that the client never
|
||||
@@ -390,82 +452,55 @@ impl FeParseMessage {
|
||||
|
||||
let _pstmt_name = read_cstr(&mut buf)?;
|
||||
let query_string = read_cstr(&mut buf)?;
|
||||
if buf.remaining() < 2 {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"Parse message is malformed, nparams missing".to_string(),
|
||||
));
|
||||
}
|
||||
let nparams = buf.get_i16();
|
||||
|
||||
if nparams != 0 {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"query params not implemented".to_string(),
|
||||
));
|
||||
}
|
||||
ensure!(nparams == 0, "query params not implemented");
|
||||
|
||||
Ok(FeMessage::Parse(FeParseMessage { query_string }))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeDescribeMessage {
|
||||
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
let kind = buf.get_u8();
|
||||
let _pstmt_name = read_cstr(&mut buf)?;
|
||||
|
||||
// FIXME: see FeParseMessage::parse
|
||||
if kind != b'S' {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"only prepared statemement Describe is implemented".to_string(),
|
||||
));
|
||||
}
|
||||
ensure!(
|
||||
kind == b'S',
|
||||
"only prepared statemement Describe is implemented"
|
||||
);
|
||||
|
||||
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeExecuteMessage {
|
||||
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
let portal_name = read_cstr(&mut buf)?;
|
||||
if buf.remaining() < 4 {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"FeExecuteMessage message is malformed, maxrows missing".to_string(),
|
||||
));
|
||||
}
|
||||
let maxrows = buf.get_i32();
|
||||
|
||||
if !portal_name.is_empty() {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"named portals not implemented".to_string(),
|
||||
));
|
||||
}
|
||||
if maxrows != 0 {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"row limit in Execute message not implemented".to_string(),
|
||||
));
|
||||
}
|
||||
ensure!(portal_name.is_empty(), "named portals not implemented");
|
||||
ensure!(maxrows == 0, "row limit in Execute message not implemented");
|
||||
|
||||
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeBindMessage {
|
||||
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
let portal_name = read_cstr(&mut buf)?;
|
||||
let _pstmt_name = read_cstr(&mut buf)?;
|
||||
|
||||
// FIXME: see FeParseMessage::parse
|
||||
if !portal_name.is_empty() {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"named portals not implemented".to_string(),
|
||||
));
|
||||
}
|
||||
ensure!(portal_name.is_empty(), "named portals not implemented");
|
||||
|
||||
Ok(FeMessage::Bind(FeBindMessage))
|
||||
}
|
||||
}
|
||||
|
||||
impl FeCloseMessage {
|
||||
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
|
||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
let _kind = buf.get_u8();
|
||||
let _pstmt_or_portal_name = read_cstr(&mut buf)?;
|
||||
|
||||
@@ -494,7 +529,6 @@ pub enum BeMessage<'a> {
|
||||
CloseComplete,
|
||||
// None means column is NULL
|
||||
DataRow(&'a [Option<&'a [u8]>]),
|
||||
// None errcode means internal_error will be sent.
|
||||
ErrorResponse(&'a str, Option<&'a [u8; 5]>),
|
||||
/// Single byte - used in response to SSLRequest/GSSENCRequest.
|
||||
EncryptionResponse(bool),
|
||||
@@ -525,11 +559,6 @@ impl<'a> BeMessage<'a> {
|
||||
value: b"UTF8",
|
||||
};
|
||||
|
||||
pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
|
||||
name: b"integer_datetimes",
|
||||
value: b"on",
|
||||
};
|
||||
|
||||
/// Build a [`BeMessage::ParameterStatus`] holding the server version.
|
||||
pub fn server_version(version: &'a str) -> Self {
|
||||
Self::ParameterStatus {
|
||||
@@ -608,7 +637,7 @@ impl RowDescriptor<'_> {
|
||||
#[derive(Debug)]
|
||||
pub struct XLogDataBody<'a> {
|
||||
pub wal_start: u64,
|
||||
pub wal_end: u64, // current end of WAL on the server
|
||||
pub wal_end: u64,
|
||||
pub timestamp: i64,
|
||||
pub data: &'a [u8],
|
||||
}
|
||||
@@ -648,11 +677,12 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
|
||||
}
|
||||
|
||||
/// Safe write of s into buf as cstring (String in the protocol).
|
||||
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
|
||||
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
|
||||
let bytes = s.as_ref();
|
||||
if bytes.contains(&0) {
|
||||
return Err(ProtocolError::BadMessage(
|
||||
"string contains embedded null".to_owned(),
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"string contains embedded null",
|
||||
));
|
||||
}
|
||||
buf.put_slice(bytes);
|
||||
@@ -660,27 +690,22 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolErr
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read cstring from buf, advancing it.
|
||||
fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
|
||||
let pos = buf
|
||||
.iter()
|
||||
.position(|x| *x == 0)
|
||||
.ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
|
||||
let result = buf.split_to(pos);
|
||||
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
|
||||
let pos = buf.iter().position(|x| *x == 0);
|
||||
let result = buf.split_to(pos.context("missing terminator")?);
|
||||
buf.advance(1); // drop the null terminator
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
|
||||
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
|
||||
|
||||
impl<'a> BeMessage<'a> {
|
||||
/// Serialize `message` to the given `buf`.
|
||||
/// Apart from smart memory managemet, BytesMut is good here as msg len
|
||||
/// precedes its body and it is handy to write it down first and then fill
|
||||
/// the length. With Write we would have to either calc it manually or have
|
||||
/// one more buffer.
|
||||
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
|
||||
/// Write message to the given buf.
|
||||
// Unlike the reading side, we use BytesMut
|
||||
// here as msg len precedes its body and it is handy to write it down first
|
||||
// and then fill the length. With Write we would have to either calc it
|
||||
// manually or have one more buffer.
|
||||
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
|
||||
match message {
|
||||
BeMessage::AuthenticationOk => {
|
||||
buf.put_u8(b'R');
|
||||
@@ -725,7 +750,7 @@ impl<'a> BeMessage<'a> {
|
||||
buf.put_slice(extra);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
Ok::<_, io::Error>(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -829,7 +854,7 @@ impl<'a> BeMessage<'a> {
|
||||
write_cstr(error_msg, buf)?;
|
||||
|
||||
buf.put_u8(0); // terminator
|
||||
Ok(())
|
||||
Ok::<_, io::Error>(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -852,7 +877,7 @@ impl<'a> BeMessage<'a> {
|
||||
write_cstr(error_msg.as_bytes(), buf)?;
|
||||
|
||||
buf.put_u8(0); // terminator
|
||||
Ok(())
|
||||
Ok::<_, io::Error>(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -907,7 +932,7 @@ impl<'a> BeMessage<'a> {
|
||||
buf.put_i32(-1); /* typmod */
|
||||
buf.put_i16(0); /* format code */
|
||||
}
|
||||
Ok(())
|
||||
Ok::<_, io::Error>(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -974,7 +999,7 @@ impl ReplicationFeedback {
|
||||
// null-terminated string - key,
|
||||
// uint32 - value length in bytes
|
||||
// value itself
|
||||
pub fn serialize(&self, buf: &mut BytesMut) {
|
||||
pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> {
|
||||
buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys
|
||||
buf.put_slice(b"current_timeline_size\0");
|
||||
buf.put_i32(8);
|
||||
@@ -999,6 +1024,7 @@ impl ReplicationFeedback {
|
||||
buf.put_slice(b"ps_replytime\0");
|
||||
buf.put_i32(8);
|
||||
buf.put_i64(timestamp);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Deserialize ReplicationFeedback message
|
||||
@@ -1066,7 +1092,7 @@ mod tests {
|
||||
// because it is rounded up to microseconds during serialization.
|
||||
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
|
||||
let mut data = BytesMut::new();
|
||||
rf.serialize(&mut data);
|
||||
rf.serialize(&mut data).unwrap();
|
||||
|
||||
let rf_parsed = ReplicationFeedback::parse(data.freeze());
|
||||
assert_eq!(rf, rf_parsed);
|
||||
@@ -1081,7 +1107,7 @@ mod tests {
|
||||
// because it is rounded up to microseconds during serialization.
|
||||
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
|
||||
let mut data = BytesMut::new();
|
||||
rf.serialize(&mut data);
|
||||
rf.serialize(&mut data).unwrap();
|
||||
|
||||
// Add an extra field to the buffer and adjust number of keys
|
||||
if let Some(first) = data.first_mut() {
|
||||
@@ -1123,6 +1149,15 @@ mod tests {
|
||||
let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
|
||||
assert_eq!(split_options(¶ms), ["foo bar", " \\", "baz ", "lol"]);
|
||||
}
|
||||
|
||||
// Make sure that `read` is sync/async callable
|
||||
async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
|
||||
let _ = FeMessage::read(&mut [].as_ref());
|
||||
let _ = FeMessage::read_fut(stream).await;
|
||||
|
||||
let _ = FeStartupPacket::read(&mut [].as_ref());
|
||||
let _ = FeStartupPacket::read_fut(stream).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
|
||||
|
||||
179
libs/pq_proto/src/sync.rs
Normal file
179
libs/pq_proto/src/sync.rs
Normal file
@@ -0,0 +1,179 @@
|
||||
use pin_project_lite::pin_project;
|
||||
use std::future::Future;
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
use std::{io, task};
|
||||
|
||||
pin_project! {
|
||||
/// We use this future to mark certain methods
|
||||
/// as callable in both sync and async modes.
|
||||
#[repr(transparent)]
|
||||
pub struct SyncFuture<S, T: Future> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
_marker: PhantomData<S>,
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper lets us synchronously wait for inner future's completion
|
||||
/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**.
|
||||
/// For instance, `S` may be substituted with types implementing
|
||||
/// [`tokio::io::AsyncRead`], but it's not the only viable option.
|
||||
impl<S, T: Future> SyncFuture<S, T> {
|
||||
/// NOTE: caller should carefully pick a type for `S`,
|
||||
/// because we don't want to enable [`SyncFuture::wait`] when
|
||||
/// it's in fact impossible to run the future synchronously.
|
||||
/// Violation of this contract will not cause UB, but
|
||||
/// panics and async event loop freezes won't please you.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// # use pq_proto::sync::SyncFuture;
|
||||
/// # use std::future::Future;
|
||||
/// # use tokio::io::AsyncReadExt;
|
||||
/// #
|
||||
/// // Parse a pair of numbers from a stream
|
||||
/// pub fn parse_pair<Reader>(
|
||||
/// stream: &mut Reader,
|
||||
/// ) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<(u32, u64)>> + '_>
|
||||
/// where
|
||||
/// Reader: tokio::io::AsyncRead + Unpin,
|
||||
/// {
|
||||
/// // If `Reader` is a `SyncProof`, this will give caller
|
||||
/// // an opportunity to use `SyncFuture::wait`, because
|
||||
/// // `.await` will always result in `Poll::Ready`.
|
||||
/// SyncFuture::new(async move {
|
||||
/// let x = stream.read_u32().await?;
|
||||
/// let y = stream.read_u64().await?;
|
||||
/// Ok((x, y))
|
||||
/// })
|
||||
/// }
|
||||
/// ```
|
||||
pub fn new(inner: T) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T: Future> Future for SyncFuture<S, T> {
|
||||
type Output = T::Output;
|
||||
|
||||
/// In async code, [`SyncFuture`] behaves like a regular wrapper.
|
||||
#[inline(always)]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
|
||||
self.project().inner.poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Postulates that we can call [`SyncFuture::wait`].
|
||||
/// If implementer is also a [`Future`], it should always
|
||||
/// return [`task::Poll::Ready`] from [`Future::poll`].
|
||||
///
|
||||
/// Each implementation should document which futures
|
||||
/// specifically are being declared sync-proof.
|
||||
pub trait SyncPostulate {}
|
||||
|
||||
impl<T: SyncPostulate> SyncPostulate for &T {}
|
||||
impl<T: SyncPostulate> SyncPostulate for &mut T {}
|
||||
|
||||
impl<P: SyncPostulate, T: Future> SyncFuture<P, T> {
|
||||
/// Synchronously wait for future completion.
|
||||
pub fn wait(mut self) -> T::Output {
|
||||
const RAW_WAKER: task::RawWaker = task::RawWaker::new(
|
||||
std::ptr::null(),
|
||||
&task::RawWakerVTable::new(
|
||||
|_| RAW_WAKER,
|
||||
|_| panic!("SyncFuture: failed to wake"),
|
||||
|_| panic!("SyncFuture: failed to wake by ref"),
|
||||
|_| { /* drop is no-op */ },
|
||||
),
|
||||
);
|
||||
|
||||
// SAFETY: We never move `self` during this call;
|
||||
// furthermore, it will be dropped in the end regardless of panics
|
||||
let this = unsafe { Pin::new_unchecked(&mut self) };
|
||||
|
||||
// SAFETY: This waker doesn't do anything apart from panicking
|
||||
let waker = unsafe { task::Waker::from_raw(RAW_WAKER) };
|
||||
let context = &mut task::Context::from_waker(&waker);
|
||||
|
||||
match this.poll(context) {
|
||||
task::Poll::Ready(res) => res,
|
||||
_ => panic!("SyncFuture: unexpected pending!"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`],
|
||||
/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`].
|
||||
/// NOTE: you **should not** use this in async code.
|
||||
#[repr(transparent)]
|
||||
pub struct AsyncishRead<T: io::Read + Unpin>(pub T);
|
||||
|
||||
/// This lets us call [`SyncFuture<AsyncishRead<_>, _>::wait`],
|
||||
/// and allows the future to await on any of the [`AsyncRead`]
|
||||
/// and [`AsyncReadExt`] methods on `AsyncishRead`.
|
||||
impl<T: io::Read + Unpin> SyncPostulate for AsyncishRead<T> {}
|
||||
|
||||
impl<T: io::Read + Unpin> tokio::io::AsyncRead for AsyncishRead<T> {
|
||||
#[inline(always)]
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
task::Poll::Ready(
|
||||
// `Read::read` will block, meaning we don't need a real event loop!
|
||||
self.0
|
||||
.read(buf.initialize_unfilled())
|
||||
.map(|sz| buf.advance(sz)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
// async helper(stream: &mut impl AsyncRead) -> io::Result<u32>
|
||||
fn bytes_add<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = io::Result<u32>> + '_>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
SyncFuture::new(async move {
|
||||
let a = stream.read_u32().await?;
|
||||
let b = stream.read_u32().await?;
|
||||
Ok(a + b)
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sync() {
|
||||
let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat();
|
||||
let res = bytes_add(&mut AsyncishRead(&mut &bytes[..]))
|
||||
.wait()
|
||||
.unwrap();
|
||||
assert_eq!(res, 300);
|
||||
}
|
||||
|
||||
// We need a single-threaded executor for this test
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn test_async() {
|
||||
let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap();
|
||||
|
||||
let write = async move {
|
||||
tx.write_u32(100).await?;
|
||||
tx.write_u32(200).await?;
|
||||
Ok(())
|
||||
};
|
||||
|
||||
let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap();
|
||||
assert_eq!(res, 300);
|
||||
}
|
||||
}
|
||||
@@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
|
||||
}
|
||||
|
||||
pub struct Download {
|
||||
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send + Sync>>,
|
||||
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send>>,
|
||||
/// Extra key-value data, associated with the current remote file.
|
||||
pub metadata: Option<StorageMetadata>,
|
||||
}
|
||||
|
||||
@@ -12,37 +12,42 @@ anyhow.workspace = true
|
||||
bincode.workspace = true
|
||||
bytes.workspace = true
|
||||
heapless.workspace = true
|
||||
hex = { workspace = true, features = ["serde"] }
|
||||
hyper = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true}
|
||||
jsonwebtoken.workspace = true
|
||||
nix.workspace = true
|
||||
once_cell.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
routerify.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
signal-hook.workspace = true
|
||||
thiserror.workspace = true
|
||||
tokio.workspace = true
|
||||
tokio-rustls.workspace = true
|
||||
tracing.workspace = true
|
||||
tracing-subscriber = { workspace = true, features = ["json"] }
|
||||
nix.workspace = true
|
||||
signal-hook.workspace = true
|
||||
rand.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
hex = { workspace = true, features = ["serde"] }
|
||||
rustls.workspace = true
|
||||
rustls-split.workspace = true
|
||||
git-version.workspace = true
|
||||
serde_with.workspace = true
|
||||
once_cell.workspace = true
|
||||
strum.workspace = true
|
||||
strum_macros.workspace = true
|
||||
url.workspace = true
|
||||
uuid = { version = "1.2", features = ["v4", "serde"] }
|
||||
|
||||
metrics.workspace = true
|
||||
workspace_hack.workspace = true
|
||||
pq_proto.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
url.workspace = true
|
||||
uuid = { version = "1.2", features = ["v4", "serde"] }
|
||||
[dev-dependencies]
|
||||
byteorder.workspace = true
|
||||
bytes.workspace = true
|
||||
criterion.workspace = true
|
||||
hex-literal.workspace = true
|
||||
tempfile.workspace = true
|
||||
criterion.workspace = true
|
||||
rustls-pemfile.workspace = true
|
||||
|
||||
[[bench]]
|
||||
name = "benchmarks"
|
||||
|
||||
@@ -9,28 +9,16 @@ use std::path::Path;
|
||||
|
||||
use anyhow::Result;
|
||||
use jsonwebtoken::{
|
||||
decode, encode, Algorithm, Algorithm::*, DecodingKey, EncodingKey, Header, TokenData,
|
||||
Validation,
|
||||
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
|
||||
use crate::id::TenantId;
|
||||
|
||||
/// Algorithms accepted during validation.
|
||||
///
|
||||
/// Accept all RSA-based algorithms. We pass this list to jsonwebtoken::decode,
|
||||
/// which checks that the algorithm in the token is one of these.
|
||||
///
|
||||
/// XXX: It also fails the validation if there are any algorithms in this list that belong
|
||||
/// to different family than the token's algorithm. In other words, we can *not* list any
|
||||
/// non-RSA algorithms here, or the validation always fails with InvalidAlgorithm error.
|
||||
const ACCEPTED_ALGORITHMS: &[Algorithm] = &[RS256, RS384, RS512];
|
||||
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
|
||||
|
||||
/// Algorithm to use when generating a new token in [`encode_from_key_file`]
|
||||
const ENCODE_ALGORITHM: Algorithm = Algorithm::RS256;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Scope {
|
||||
// Provides access to all data for a specific tenant (specified in `struct Claims` below)
|
||||
@@ -45,9 +33,8 @@ pub enum Scope {
|
||||
SafekeeperData,
|
||||
}
|
||||
|
||||
/// JWT payload. See docs/authentication.md for the format
|
||||
#[serde_as]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Claims {
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
@@ -68,8 +55,7 @@ pub struct JwtAuth {
|
||||
|
||||
impl JwtAuth {
|
||||
pub fn new(decoding_key: DecodingKey) -> Self {
|
||||
let mut validation = Validation::default();
|
||||
validation.algorithms = ACCEPTED_ALGORITHMS.into();
|
||||
let mut validation = Validation::new(JWT_ALGORITHM);
|
||||
// The default 'required_spec_claims' is 'exp'. But we don't want to require
|
||||
// expiration.
|
||||
validation.required_spec_claims = [].into();
|
||||
@@ -100,113 +86,5 @@ impl std::fmt::Debug for JwtAuth {
|
||||
// this function is used only for testing purposes in CLI e g generate tokens during init
|
||||
pub fn encode_from_key_file(claims: &Claims, key_data: &[u8]) -> Result<String> {
|
||||
let key = EncodingKey::from_rsa_pem(key_data)?;
|
||||
Ok(encode(&Header::new(ENCODE_ALGORITHM), claims, &key)?)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::str::FromStr;
|
||||
|
||||
// generated with:
|
||||
//
|
||||
// openssl genpkey -algorithm rsa -out storage-auth-priv.pem
|
||||
// openssl pkey -in storage-auth-priv.pem -pubout -out storage-auth-pub.pem
|
||||
const TEST_PUB_KEY_RSA: &[u8] = br#"
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAy6OZ+/kQXcueVJA/KTzO
|
||||
v4ljxylc/Kcb0sXWuXg1GB8k3nDA1gK66LFYToH0aTnqrnqG32Vu6wrhwuvqsZA7
|
||||
jQvP0ZePAbWhpEqho7EpNunDPcxZ/XDy5TQlB1P58F9I3lkJXDC+DsHYLuuzwhAv
|
||||
vo2MtWRdYlVHblCVLyZtANHhUMp2HUhgjHnJh5UrLIKOl4doCBxkM3rK0wjKsNCt
|
||||
M92PCR6S9rvYzldfeAYFNppBkEQrXt2CgUqZ4KaS4LXtjTRUJxljijA4HWffhxsr
|
||||
euRu3ufq8kVqie7fum0rdZZSkONmce0V0LesQ4aE2jB+2Sn48h6jb4dLXGWdq8TV
|
||||
wQIDAQAB
|
||||
-----END PUBLIC KEY-----
|
||||
"#;
|
||||
const TEST_PRIV_KEY_RSA: &[u8] = br#"
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDLo5n7+RBdy55U
|
||||
kD8pPM6/iWPHKVz8pxvSxda5eDUYHyTecMDWArrosVhOgfRpOequeobfZW7rCuHC
|
||||
6+qxkDuNC8/Rl48BtaGkSqGjsSk26cM9zFn9cPLlNCUHU/nwX0jeWQlcML4Owdgu
|
||||
67PCEC++jYy1ZF1iVUduUJUvJm0A0eFQynYdSGCMecmHlSssgo6Xh2gIHGQzesrT
|
||||
CMqw0K0z3Y8JHpL2u9jOV194BgU2mkGQRCte3YKBSpngppLgte2NNFQnGWOKMDgd
|
||||
Z9+HGyt65G7e5+ryRWqJ7t+6bSt1llKQ42Zx7RXQt6xDhoTaMH7ZKfjyHqNvh0tc
|
||||
ZZ2rxNXBAgMBAAECggEAVz3u4Wlx3o02dsoZlSQs+xf0PEX3RXKeU+1YMbtTG9Nz
|
||||
6yxpIQaoZrpbt76rJE2gwkFR+PEu1NmjoOuLb6j4KlQuI4AHz1auOoGSwFtM6e66
|
||||
K4aZ4x95oEJ3vqz2fkmEIWYJwYpMUmwvnuJx76kZm0xvROMLsu4QHS2+zCVtO5Tr
|
||||
hvS05IMVuZ2TdQBZw0+JaFdwXbgDjQnQGY5n9MoTWSx1a4s/FF4Eby65BbDutcpn
|
||||
Vt3jQAOmO1X2kbPeWSGuPJRzyUs7Kg8qfeglBIR3ppGP3vPYAdWX+ho00bmsVkSp
|
||||
Q8vjul6C3WiM+kjwDxotHSDgbl/xldAl7OqPh0bfAQKBgQDnycXuq14Vg8nZvyn9
|
||||
rTnvucO8RBz5P6G+FZ+44cAS2x79+85onARmMnm+9MKYLSMo8fOvsK034NDI68XM
|
||||
04QQ/vlfouvFklMTGJIurgEImTZbGCmlMYCvFyIxaEWixon8OpeI4rFe4Hmbiijh
|
||||
PxhxWg221AwvBS2sco8J/ylEkQKBgQDg6Rh2QYb/j0Wou1rJPbuy3NhHofd5Rq35
|
||||
4YV3f2lfVYcPrgRhwe3T9SVII7Dx8LfwzsX5TAlf48ESlI3Dzv40uOCDM+xdtBRI
|
||||
r96SfSm+jup6gsXU3AsdNkrRK3HoOG9Z/TkrUp213QAIlVnvIx65l4ckFMlpnPJ0
|
||||
lo1LDXZWMQKBgFArzjZ7N5OhfdO+9zszC3MLgdRAivT7OWqR+CjujIz5FYMr8Xzl
|
||||
WfAvTUTrS9Nu6VZkObFvHrrRG+YjBsuN7YQjbQXTSFGSBwH34bgbn2fl9pMTjHQC
|
||||
50uoaL9GHa/rlBaV/YvvPQJgCi/uXa1rMX0jdNLkDULGO8IF7cu7Yf7BAoGBAIUU
|
||||
J29BkpmAst0GDs/ogTlyR18LTR0rXyHt+UUd1MGeH859TwZw80JpWWf4BmkB4DTS
|
||||
hH3gKePdJY7S65ci0XNsuRupC4DeXuorde0DtkGU2tUmr9wlX0Ynq9lcdYfMbMa4
|
||||
eK1TsxG69JwfkxlWlIWITWRiEFM3lJa7xlrUWmLhAoGAFpKWF/hn4zYg3seU9gai
|
||||
EYHKSbhxA4mRb+F0/9IlCBPMCqFrL5yftUsYIh2XFKn8+QhO97Nmk8wJSK6TzQ5t
|
||||
ZaSRmgySrUUhx4nZ/MgqWCFv8VUbLM5MBzwxPKhXkSTfR4z2vLYLJwVY7Tb4kZtp
|
||||
8ismApXVGHpOCstzikV9W7k=
|
||||
-----END PRIVATE KEY-----
|
||||
"#;
|
||||
|
||||
#[test]
|
||||
fn test_decode() -> Result<(), anyhow::Error> {
|
||||
let expected_claims = Claims {
|
||||
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
|
||||
scope: Scope::Tenant,
|
||||
};
|
||||
|
||||
// Here are tokens containing the following payload, signed using TEST_PRIV_KEY_RSA
|
||||
// using RS512, RS384 and RS256 algorithms:
|
||||
//
|
||||
// ```
|
||||
// {
|
||||
// "scope": "tenant",
|
||||
// "tenant_id": "3d1f7595b468230304e0b73cecbcb081",
|
||||
// "iss": "neon.controlplane",
|
||||
// "exp": 1709200879,
|
||||
// "iat": 1678442479
|
||||
// }
|
||||
// ```
|
||||
//
|
||||
// These were encoded with the online debugger at https://jwt.io
|
||||
//
|
||||
let encoded_rs512 = "eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.QmqfteDQmDGoxQ5EFkasbt35Lx0W0Nh63muQnYZvFq93DSh4ZbOG9Mc4yaiXZoiS5HgeKtFKv3mbWkDqjz3En06aY17hWwguBtAsGASX48lYeCPADYGlGAuaWnOnVRwe3iiOC7tvPFvwX_45S84X73sNUXyUiXv6nLdcDqVXudtNrGST_DnZDnjuUJX11w7sebtKqQQ8l9-iGHiXOl5yevpMCoB1OcTWcT6DfDtffoNuMHDC3fyhmEGG5oKAt1qBybqAIiyC9-UBAowRZXhdfxrzUl-I9jzKWvk85c5ulhVRwbPeP6TTTlPKwFzBNHg1i2U-1GONew5osQ3aoptwsA";
|
||||
|
||||
let encoded_rs384 = "eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.qqk4nkxKzOJP38c_g57_w_SfdQVmCsDT_bsLmdFj_N6LIB22gr6U6_P_5mvk3pIAsp0VCTDwPrCU908TxqjibEkwvQoJwbogHamSGHpD7eJBxGblSnA-Nr3MlEMxpFtec8QokSm6C5mH7DoBYjB2xzeOlxAmpR2GAzInKiMkU4kZ_OcqqrmVcMXY_6VnbxZWMekuw56zE1-PP_qNF1HvYOH-P08ONP8qdo5UPtBG7QBEFlCqZXJZCFihQaI4Vzil9rDuZGCm3I7xQJ8-yh1PX3BTbGo8EzqLdRyBeTpr08UTuRbp_MJDWevHpP3afvJetAItqZXIoZQrbJjcByHqKw";
|
||||
|
||||
let encoded_rs256 = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.dF2N9KXG8ftFKHYbd5jQtXMQqv0Ej8FISGp1b_dmqOCotXj5S1y2AWjwyB_EXHM77JXfbEoJPAPrFFBNfd8cWtkCSTvpxWoHaecGzegDFGv5ZSc5AECFV1Daahc3PI3jii9wEiGkFOiwiBNfZ5INomOAsV--XXxlqIwKbTcgSYI7lrOTfecXAbAHiMKQlQYiIBSGnytRCgafhRkyGzPAL8ismthFJ9RHfeejyskht-9GbVHURw02bUyijuHEulpf9eEY3ZiB28de6jnCdU7ftIYaUMaYWt0nZQGkzxKPSfSLZNy14DTOYLDS04DVstWQPqnCUW_ojg0wJETOOfo9Zw";
|
||||
|
||||
// Check that RS512, RS384 and RS256 tokens can all be validated
|
||||
let auth = JwtAuth::new(DecodingKey::from_rsa_pem(TEST_PUB_KEY_RSA)?);
|
||||
|
||||
for encoded in [encoded_rs512, encoded_rs384, encoded_rs256] {
|
||||
let claims_from_token = auth.decode(encoded)?.claims;
|
||||
assert_eq!(claims_from_token, expected_claims);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode() -> Result<(), anyhow::Error> {
|
||||
let claims = Claims {
|
||||
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
|
||||
scope: Scope::Tenant,
|
||||
};
|
||||
|
||||
let encoded = encode_from_key_file(&claims, TEST_PRIV_KEY_RSA)?;
|
||||
|
||||
// decode it back
|
||||
let auth = JwtAuth::new(DecodingKey::from_rsa_pem(TEST_PUB_KEY_RSA)?);
|
||||
let decoded = auth.decode(&encoded)?;
|
||||
|
||||
assert_eq!(decoded.claims, claims);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(encode(&Header::new(JWT_ALGORITHM), claims, &key)?)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ where
|
||||
P: AsRef<Path>,
|
||||
{
|
||||
fn is_empty_dir(&self) -> io::Result<bool> {
|
||||
Ok(fs::read_dir(self)?.next().is_none())
|
||||
Ok(fs::read_dir(self)?.into_iter().next().is_none())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,14 +3,14 @@ use crate::http::error;
|
||||
use anyhow::{anyhow, Context};
|
||||
use hyper::header::{HeaderName, AUTHORIZATION};
|
||||
use hyper::http::HeaderValue;
|
||||
use hyper::Method;
|
||||
use hyper::{header::CONTENT_TYPE, Body, Request, Response, Server};
|
||||
use hyper::{Method, StatusCode};
|
||||
use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
|
||||
use once_cell::sync::Lazy;
|
||||
use routerify::ext::RequestExt;
|
||||
use routerify::{Middleware, RequestInfo, Router, RouterBuilder, RouterService};
|
||||
use tokio::task::JoinError;
|
||||
use tracing::{self, debug, info, info_span, warn, Instrument};
|
||||
use tracing;
|
||||
|
||||
use std::future::Future;
|
||||
use std::net::TcpListener;
|
||||
@@ -32,77 +32,31 @@ static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HE
|
||||
#[derive(Debug, Default, Clone)]
|
||||
struct RequestId(String);
|
||||
|
||||
/// Adds a tracing info_span! instrumentation around the handler events,
|
||||
/// logs the request start and end events for non-GET requests and non-200 responses.
|
||||
///
|
||||
/// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
|
||||
/// in this type will get request info logged in the wrapping span, including the unique request ID.
|
||||
///
|
||||
/// There could be other ways to implement similar functionality:
|
||||
///
|
||||
/// * procmacros placed on top of all handler methods
|
||||
/// With all the drawbacks of procmacros, brings no difference implementation-wise,
|
||||
/// and little code reduction compared to the existing approach.
|
||||
///
|
||||
/// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
|
||||
/// implemented for [`RouterBuilder`].
|
||||
/// Could be simpler, but we don't want to depend on [`routerify`] more, targeting to use other library later.
|
||||
///
|
||||
/// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
|
||||
/// later, in a post-response middleware.
|
||||
/// Due to suspendable nature of the futures, would give contradictive results which is exactly the opposite of what `tracing-futures`
|
||||
/// tries to achive with its `.instrument` used in the current approach.
|
||||
///
|
||||
/// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
|
||||
pub struct RequestSpan<E, R, H>(pub H)
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
|
||||
R: Future<Output = Result<Response<Body>, E>> + Send + 'static,
|
||||
H: Fn(Request<Body>) -> R + Send + Sync + 'static;
|
||||
async fn logger(res: Response<Body>, info: RequestInfo) -> Result<Response<Body>, ApiError> {
|
||||
let request_id = info.context::<RequestId>().unwrap_or_default().0;
|
||||
|
||||
impl<E, R, H> RequestSpan<E, R, H>
|
||||
where
|
||||
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
|
||||
R: Future<Output = Result<Response<Body>, E>> + Send + 'static,
|
||||
H: Fn(Request<Body>) -> R + Send + Sync + 'static,
|
||||
{
|
||||
/// Creates a tracing span around inner request handler and executes the request handler in the contex of that span.
|
||||
/// Use as `|r| RequestSpan(my_handler).handle(r)` instead of `my_handler` as the request handler to get the span enabled.
|
||||
pub async fn handle(self, request: Request<Body>) -> Result<Response<Body>, E> {
|
||||
let request_id = request.context::<RequestId>().unwrap_or_default().0;
|
||||
let method = request.method();
|
||||
let path = request.uri().path();
|
||||
let request_span = info_span!("request", %method, %path, %request_id);
|
||||
// cannot factor out the Level to avoid the repetition
|
||||
// because tracing can only work with const Level
|
||||
// which is not the case here
|
||||
|
||||
let log_quietly = method == Method::GET;
|
||||
async move {
|
||||
if log_quietly {
|
||||
debug!("Handling request");
|
||||
} else {
|
||||
info!("Handling request");
|
||||
}
|
||||
|
||||
// Note that we reuse `error::handler` here and not returning and error at all,
|
||||
// yet cannot use `!` directly in the method signature due to `routerify::RouterBuilder` limitation.
|
||||
// Usage of the error handler also means that we expect only the `ApiError` errors to be raised in this call.
|
||||
//
|
||||
// Panics are not handled separately, there's a `tracing_panic_hook` from another module to do that globally.
|
||||
match (self.0)(request).await {
|
||||
Ok(response) => {
|
||||
let response_status = response.status();
|
||||
if log_quietly && response_status.is_success() {
|
||||
debug!("Request handled, status: {response_status}");
|
||||
} else {
|
||||
info!("Request handled, status: {response_status}");
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
Err(e) => Ok(error::handler(e.into()).await),
|
||||
}
|
||||
}
|
||||
.instrument(request_span)
|
||||
.await
|
||||
if info.method() == Method::GET && res.status() == StatusCode::OK {
|
||||
tracing::debug!(
|
||||
"{} {} {} {}",
|
||||
info.method(),
|
||||
info.uri().path(),
|
||||
request_id,
|
||||
res.status()
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
"{} {} {} {}",
|
||||
info.method(),
|
||||
info.uri().path(),
|
||||
request_id,
|
||||
res.status()
|
||||
);
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
@@ -142,6 +96,12 @@ pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'stati
|
||||
request_id.to_string()
|
||||
}
|
||||
};
|
||||
|
||||
if req.method() == Method::GET {
|
||||
tracing::debug!("{} {} {}", req.method(), req.uri().path(), request_id);
|
||||
} else {
|
||||
tracing::info!("{} {} {}", req.method(), req.uri().path(), request_id);
|
||||
}
|
||||
req.set_context(RequestId(request_id));
|
||||
|
||||
Ok(req)
|
||||
@@ -165,12 +125,11 @@ async fn add_request_id_header_to_response(
|
||||
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
|
||||
Router::builder()
|
||||
.middleware(add_request_id_middleware())
|
||||
.middleware(Middleware::post_with_info(logger))
|
||||
.middleware(Middleware::post_with_info(
|
||||
add_request_id_header_to_response,
|
||||
))
|
||||
.get("/metrics", |r| {
|
||||
RequestSpan(prometheus_metrics_handler).handle(r)
|
||||
})
|
||||
.get("/metrics", prometheus_metrics_handler)
|
||||
.err_handler(error::handler)
|
||||
}
|
||||
|
||||
@@ -180,43 +139,40 @@ pub fn attach_openapi_ui(
|
||||
spec_mount_path: &'static str,
|
||||
ui_mount_path: &'static str,
|
||||
) -> RouterBuilder<hyper::Body, ApiError> {
|
||||
router_builder
|
||||
.get(spec_mount_path, move |r| {
|
||||
RequestSpan(move |_| async move { Ok(Response::builder().body(Body::from(spec)).unwrap()) })
|
||||
.handle(r)
|
||||
})
|
||||
.get(ui_mount_path, move |r| RequestSpan( move |_| async move {
|
||||
Ok(Response::builder().body(Body::from(format!(r#"
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>rweb</title>
|
||||
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
|
||||
</head>
|
||||
<body>
|
||||
<div id="swagger-ui"></div>
|
||||
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
|
||||
<script>
|
||||
window.onload = function() {{
|
||||
const ui = SwaggerUIBundle({{
|
||||
"dom_id": "\#swagger-ui",
|
||||
presets: [
|
||||
SwaggerUIBundle.presets.apis,
|
||||
SwaggerUIBundle.SwaggerUIStandalonePreset
|
||||
],
|
||||
layout: "BaseLayout",
|
||||
deepLinking: true,
|
||||
showExtensions: true,
|
||||
showCommonExtensions: true,
|
||||
url: "{}",
|
||||
}})
|
||||
window.ui = ui;
|
||||
}};
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"#, spec_mount_path))).unwrap())
|
||||
}).handle(r))
|
||||
router_builder.get(spec_mount_path, move |_| async move {
|
||||
Ok(Response::builder().body(Body::from(spec)).unwrap())
|
||||
}).get(ui_mount_path, move |_| async move {
|
||||
Ok(Response::builder().body(Body::from(format!(r#"
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>rweb</title>
|
||||
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
|
||||
</head>
|
||||
<body>
|
||||
<div id="swagger-ui"></div>
|
||||
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
|
||||
<script>
|
||||
window.onload = function() {{
|
||||
const ui = SwaggerUIBundle({{
|
||||
"dom_id": "\#swagger-ui",
|
||||
presets: [
|
||||
SwaggerUIBundle.presets.apis,
|
||||
SwaggerUIBundle.SwaggerUIStandalonePreset
|
||||
],
|
||||
layout: "BaseLayout",
|
||||
deepLinking: true,
|
||||
showExtensions: true,
|
||||
showCommonExtensions: true,
|
||||
url: "{}",
|
||||
}})
|
||||
window.ui = ui;
|
||||
}};
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"#, spec_mount_path))).unwrap())
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
|
||||
@@ -278,7 +234,7 @@ where
|
||||
async move {
|
||||
let headers = response.headers_mut();
|
||||
if headers.contains_key(&name) {
|
||||
warn!(
|
||||
tracing::warn!(
|
||||
"{} response already contains header {:?}",
|
||||
request_info.uri(),
|
||||
&name,
|
||||
@@ -318,7 +274,7 @@ pub fn serve_thread_main<S>(
|
||||
where
|
||||
S: Future<Output = ()> + Send + Sync,
|
||||
{
|
||||
info!("Starting an HTTP endpoint at {}", listener.local_addr()?);
|
||||
tracing::info!("Starting an HTTP endpoint at {}", listener.local_addr()?);
|
||||
|
||||
// Create a Service from the router above to handle incoming requests.
|
||||
let service = RouterService::new(router_builder.build().map_err(|err| anyhow!(err))?).unwrap();
|
||||
|
||||
@@ -13,6 +13,8 @@ pub mod simple_rcu;
|
||||
pub mod vec_map;
|
||||
|
||||
pub mod bin_ser;
|
||||
pub mod postgres_backend;
|
||||
pub mod postgres_backend_async;
|
||||
|
||||
// helper functions for creating and fsyncing
|
||||
pub mod crashsafe;
|
||||
@@ -25,6 +27,9 @@ pub mod id;
|
||||
// http endpoint utils
|
||||
pub mod http;
|
||||
|
||||
// socket splitting utils
|
||||
pub mod sock_split;
|
||||
|
||||
// common log initialisation routine
|
||||
pub mod logging;
|
||||
|
||||
@@ -49,8 +54,6 @@ pub mod fs_ext;
|
||||
|
||||
pub mod history_buffer;
|
||||
|
||||
pub mod measured_stream;
|
||||
|
||||
/// use with fail::cfg("$name", "return(2000)")
|
||||
#[macro_export]
|
||||
macro_rules! failpoint_sleep_millis_async {
|
||||
|
||||
@@ -6,7 +6,6 @@ use std::ops::{Add, AddAssign};
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tracing::info;
|
||||
|
||||
use crate::seqwait::MonotonicCounter;
|
||||
|
||||
@@ -240,7 +239,6 @@ impl MonotonicCounter<Lsn> for RecordLsn {
|
||||
let new_prev = self.last;
|
||||
self.last = lsn;
|
||||
self.prev = new_prev;
|
||||
info!("advanced record lsn to {}/{}", self.last, self.prev);
|
||||
}
|
||||
fn cnt_value(&self) -> Lsn {
|
||||
self.last
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
use pin_project_lite::pin_project;
|
||||
use std::pin::Pin;
|
||||
use std::{io, task};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
pin_project! {
|
||||
/// This stream tracks all writes and calls user provided
|
||||
/// callback when the underlying stream is flushed.
|
||||
pub struct MeasuredStream<S, R, W> {
|
||||
#[pin]
|
||||
stream: S,
|
||||
write_count: usize,
|
||||
inc_read_count: R,
|
||||
inc_write_count: W,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, R, W> MeasuredStream<S, R, W> {
|
||||
pub fn new(stream: S, inc_read_count: R, inc_write_count: W) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
write_count: 0,
|
||||
inc_read_count,
|
||||
inc_write_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin, R: FnMut(usize), W> AsyncRead for MeasuredStream<S, R, W> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
let filled = buf.filled().len();
|
||||
this.stream.poll_read(context, buf).map_ok(|()| {
|
||||
let cnt = buf.filled().len() - filled;
|
||||
// Increment the read count.
|
||||
(this.inc_read_count)(cnt);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin, R, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, R, W> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
let this = self.project();
|
||||
this.stream.poll_write(context, buf).map_ok(|cnt| {
|
||||
// Increment the write count.
|
||||
*this.write_count += cnt;
|
||||
cnt
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
this.stream.poll_flush(context).map_ok(|()| {
|
||||
// Call the user provided callback and reset the write count.
|
||||
(this.inc_write_count)(*this.write_count);
|
||||
*this.write_count = 0;
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
self.project().stream.poll_shutdown(context)
|
||||
}
|
||||
}
|
||||
485
libs/utils/src/postgres_backend.rs
Normal file
485
libs/utils/src/postgres_backend.rs
Normal file
@@ -0,0 +1,485 @@
|
||||
//! Server-side synchronous Postgres connection, as limited as we need.
|
||||
//! To use, create PostgresBackend and run() it, passing the Handler
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
|
||||
use crate::postgres_backend_async::{log_query_error, short_error, QueryError};
|
||||
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
|
||||
use anyhow::Context;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::io::{self, Write};
|
||||
use std::net::{Shutdown, SocketAddr, TcpStream};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::*;
|
||||
|
||||
pub trait Handler {
|
||||
/// Handle single query.
|
||||
/// postgres_backend will issue ReadyForQuery after calling this (this
|
||||
/// might be not what we want after CopyData streaming, but currently we don't
|
||||
/// care).
|
||||
fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError>;
|
||||
|
||||
/// Called on startup packet receival, allows to process params.
|
||||
///
|
||||
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
|
||||
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
|
||||
/// to override whole init logic in implementations.
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_sm: &FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check auth jwt
|
||||
fn check_auth_jwt(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_jwt_response: &[u8],
|
||||
) -> Result<(), QueryError> {
|
||||
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
|
||||
}
|
||||
|
||||
fn is_shutdown_requested(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// PostgresBackend protocol state.
|
||||
/// XXX: The order of the constructors matters.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
|
||||
pub enum ProtoState {
|
||||
Initialization,
|
||||
Encrypted,
|
||||
Authentication,
|
||||
Established,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum AuthType {
|
||||
Trust,
|
||||
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
|
||||
NeonJWT,
|
||||
}
|
||||
|
||||
impl FromStr for AuthType {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"Trust" => Ok(Self::Trust),
|
||||
"NeonJWT" => Ok(Self::NeonJWT),
|
||||
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
AuthType::Trust => "Trust",
|
||||
AuthType::NeonJWT => "NeonJWT",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ProcessMsgResult {
|
||||
Continue,
|
||||
Break,
|
||||
}
|
||||
|
||||
/// Always-writeable sock_split stream.
|
||||
/// May not be readable. See [`PostgresBackend::take_stream_in`]
|
||||
pub enum Stream {
|
||||
Bidirectional(BidiStream),
|
||||
WriteOnly(WriteStream),
|
||||
}
|
||||
|
||||
impl Stream {
|
||||
fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how),
|
||||
Self::WriteOnly(write_stream) => write_stream.shutdown(how),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for Stream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Bidirectional(bidi_stream) => bidi_stream.write(buf),
|
||||
Self::WriteOnly(write_stream) => write_stream.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Bidirectional(bidi_stream) => bidi_stream.flush(),
|
||||
Self::WriteOnly(write_stream) => write_stream.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
stream: Option<Stream>,
|
||||
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
|
||||
buf_out: BytesMut,
|
||||
|
||||
pub state: ProtoState,
|
||||
|
||||
auth_type: AuthType,
|
||||
|
||||
peer_addr: SocketAddr,
|
||||
pub tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
}
|
||||
|
||||
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
|
||||
let mut query_string = query_string.to_vec();
|
||||
if let Some(ch) = query_string.last() {
|
||||
if *ch == 0 {
|
||||
query_string.pop();
|
||||
}
|
||||
}
|
||||
query_string
|
||||
}
|
||||
|
||||
// Helper function for socket read loops
|
||||
pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool {
|
||||
for cause in error.chain() {
|
||||
if let Some(io_error) = cause.downcast_ref::<io::Error>() {
|
||||
if io_error.kind() == std::io::ErrorKind::WouldBlock {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Cast a byte slice to a string slice, dropping null terminator if there's one.
|
||||
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
|
||||
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
|
||||
std::str::from_utf8(without_null).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
impl PostgresBackend {
|
||||
pub fn new(
|
||||
socket: TcpStream,
|
||||
auth_type: AuthType,
|
||||
tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
set_read_timeout: bool,
|
||||
) -> io::Result<Self> {
|
||||
let peer_addr = socket.peer_addr()?;
|
||||
if set_read_timeout {
|
||||
socket
|
||||
.set_read_timeout(Some(Duration::from_secs(5)))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))),
|
||||
buf_out: BytesMut::with_capacity(10 * 1024),
|
||||
state: ProtoState::Initialization,
|
||||
auth_type,
|
||||
tls_config,
|
||||
peer_addr,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn into_stream(self) -> Stream {
|
||||
self.stream.unwrap()
|
||||
}
|
||||
|
||||
/// Get direct reference (into the Option) to the read stream.
|
||||
fn get_stream_in(&mut self) -> anyhow::Result<&mut BidiStream> {
|
||||
match &mut self.stream {
|
||||
Some(Stream::Bidirectional(stream)) => Ok(stream),
|
||||
_ => anyhow::bail!("reader taken"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_peer_addr(&self) -> &SocketAddr {
|
||||
&self.peer_addr
|
||||
}
|
||||
|
||||
pub fn take_stream_in(&mut self) -> Option<ReadStream> {
|
||||
let stream = self.stream.take();
|
||||
match stream {
|
||||
Some(Stream::Bidirectional(bidi_stream)) => {
|
||||
let (read, write) = bidi_stream.split();
|
||||
self.stream = Some(Stream::WriteOnly(write));
|
||||
Some(read)
|
||||
}
|
||||
stream => {
|
||||
self.stream = stream;
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read full message or return None if connection is closed.
|
||||
pub fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
|
||||
let (state, stream) = (self.state, self.get_stream_in()?);
|
||||
|
||||
use ProtoState::*;
|
||||
match state {
|
||||
Initialization | Encrypted => FeStartupPacket::read(stream),
|
||||
Authentication | Established => FeMessage::read(stream),
|
||||
}
|
||||
.map_err(QueryError::from)
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer.
|
||||
pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
|
||||
BeMessage::write(&mut self.buf_out, message)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Flush output buffer into the socket.
|
||||
pub fn flush(&mut self) -> io::Result<&mut Self> {
|
||||
let stream = self.stream.as_mut().unwrap();
|
||||
stream.write_all(&self.buf_out)?;
|
||||
self.buf_out.clear();
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Write message into internal buffer and flush it.
|
||||
pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
|
||||
self.write_message_noflush(message)?;
|
||||
self.flush()
|
||||
}
|
||||
|
||||
// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
pub fn run(mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
|
||||
let ret = self.run_message_loop(handler);
|
||||
if let Some(stream) = self.stream.as_mut() {
|
||||
let _ = stream.shutdown(Shutdown::Both);
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
|
||||
trace!("postgres backend to {:?} started", self.peer_addr);
|
||||
|
||||
let mut unnamed_query_string = Bytes::new();
|
||||
|
||||
while !handler.is_shutdown_requested() {
|
||||
match self.read_message() {
|
||||
Ok(message) => {
|
||||
if let Some(msg) = message {
|
||||
trace!("got message {msg:?}");
|
||||
|
||||
match self.process_message(handler, msg, &mut unnamed_query_string)? {
|
||||
ProcessMsgResult::Continue => continue,
|
||||
ProcessMsgResult::Break => break,
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if let QueryError::Other(e) = &e {
|
||||
if is_socket_read_timed_out(e) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!("postgres backend to {:?} exited", self.peer_addr);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn start_tls(&mut self) -> anyhow::Result<()> {
|
||||
match self.stream.take() {
|
||||
Some(Stream::Bidirectional(bidi_stream)) => {
|
||||
let conn = rustls::ServerConnection::new(self.tls_config.clone().unwrap())?;
|
||||
self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(conn)?));
|
||||
Ok(())
|
||||
}
|
||||
stream => {
|
||||
self.stream = stream;
|
||||
anyhow::bail!("can't start TLs without bidi stream");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
msg: FeMessage,
|
||||
unnamed_query_string: &mut Bytes,
|
||||
) -> Result<ProcessMsgResult, QueryError> {
|
||||
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
|
||||
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
|
||||
if self.state < ProtoState::Established
|
||||
&& !matches!(
|
||||
msg,
|
||||
FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_)
|
||||
)
|
||||
{
|
||||
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
|
||||
}
|
||||
|
||||
let have_tls = self.tls_config.is_some();
|
||||
match msg {
|
||||
FeMessage::StartupPacket(m) => {
|
||||
trace!("got startup message {m:?}");
|
||||
|
||||
match m {
|
||||
FeStartupPacket::SslRequest => {
|
||||
debug!("SSL requested");
|
||||
|
||||
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
|
||||
if have_tls {
|
||||
self.start_tls()?;
|
||||
self.state = ProtoState::Encrypted;
|
||||
}
|
||||
}
|
||||
FeStartupPacket::GssEncRequest => {
|
||||
debug!("GSS requested");
|
||||
self.write_message(&BeMessage::EncryptionResponse(false))?;
|
||||
}
|
||||
FeStartupPacket::StartupMessage { .. } => {
|
||||
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
"must connect with TLS",
|
||||
None,
|
||||
))?;
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"client did not connect with TLS"
|
||||
)));
|
||||
}
|
||||
|
||||
// NB: startup() may change self.auth_type -- we are using that in proxy code
|
||||
// to bypass auth for new users.
|
||||
handler.startup(self, &m)?;
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
// The async python driver requires a valid server_version
|
||||
.write_message_noflush(&BeMessage::server_version("14.1"))?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
AuthType::NeonJWT => {
|
||||
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
|
||||
self.state = ProtoState::Authentication;
|
||||
}
|
||||
}
|
||||
}
|
||||
FeStartupPacket::CancelRequest { .. } => {
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FeMessage::PasswordMessage(m) => {
|
||||
trace!("got password message '{:?}'", m);
|
||||
|
||||
assert!(self.state == ProtoState::Authentication);
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => unreachable!(),
|
||||
AuthType::NeonJWT => {
|
||||
let (_, jwt_response) = m.split_last().context("protocol violation")?;
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
|
||||
FeMessage::Query(body) => {
|
||||
// remove null terminator
|
||||
let query_string = cstr_to_str(&body)?;
|
||||
|
||||
trace!("got query {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string) {
|
||||
log_query_error(query_string, &e);
|
||||
let short_error = short_error(&e);
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&short_error,
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
|
||||
FeMessage::Parse(m) => {
|
||||
*unnamed_query_string = m.query_string;
|
||||
self.write_message(&BeMessage::ParseComplete)?;
|
||||
}
|
||||
|
||||
FeMessage::Describe(_) => {
|
||||
self.write_message_noflush(&BeMessage::ParameterDescription)?
|
||||
.write_message(&BeMessage::NoData)?;
|
||||
}
|
||||
|
||||
FeMessage::Bind(_) => {
|
||||
self.write_message(&BeMessage::BindComplete)?;
|
||||
}
|
||||
|
||||
FeMessage::Close(_) => {
|
||||
self.write_message(&BeMessage::CloseComplete)?;
|
||||
}
|
||||
|
||||
FeMessage::Execute(_) => {
|
||||
let query_string = cstr_to_str(unnamed_query_string)?;
|
||||
trace!("got execute {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string) {
|
||||
log_query_error(query_string, &e);
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
// NOTE there is no ReadyForQuery message. This handler is used
|
||||
// for basebackup and it uses CopyOut which doesn't require
|
||||
// ReadyForQuery message and backend just switches back to
|
||||
// processing mode after sending CopyDone or ErrorResponse.
|
||||
}
|
||||
|
||||
FeMessage::Sync => {
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
|
||||
FeMessage::Terminate => {
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
|
||||
// We prefer explicit pattern matching to wildcards, because
|
||||
// this helps us spot the places where new variants are missing
|
||||
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unexpected message type: {msg:?}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ProcessMsgResult::Continue)
|
||||
}
|
||||
}
|
||||
634
libs/utils/src/postgres_backend_async.rs
Normal file
634
libs/utils/src/postgres_backend_async.rs
Normal file
@@ -0,0 +1,634 @@
|
||||
//! Server-side asynchronous Postgres connection, as limited as we need.
|
||||
//! To use, create PostgresBackend and run() it, passing the Handler
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
|
||||
use crate::postgres_backend::AuthType;
|
||||
use anyhow::Context;
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Poll;
|
||||
use std::{future::Future, task::ready};
|
||||
use tracing::{debug, error, info, trace};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
pub fn is_expected_io_error(e: &io::Error) -> bool {
|
||||
use io::ErrorKind::*;
|
||||
matches!(
|
||||
e.kind(),
|
||||
ConnectionRefused | ConnectionAborted | ConnectionReset
|
||||
)
|
||||
}
|
||||
|
||||
/// An error, occurred during query processing:
|
||||
/// either during the connection ([`ConnectionError`]) or before/after it.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum QueryError {
|
||||
/// The connection was lost while processing the query.
|
||||
#[error(transparent)]
|
||||
Disconnected(#[from] ConnectionError),
|
||||
/// Some other error
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl From<io::Error> for QueryError {
|
||||
fn from(e: io::Error) -> Self {
|
||||
Self::Disconnected(ConnectionError::Socket(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryError {
|
||||
pub fn pg_error_code(&self) -> &'static [u8; 5] {
|
||||
match self {
|
||||
Self::Disconnected(_) => b"08006", // connection failure
|
||||
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Handler {
|
||||
/// Handle single query.
|
||||
/// postgres_backend will issue ReadyForQuery after calling this (this
|
||||
/// might be not what we want after CopyData streaming, but currently we don't
|
||||
/// care).
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError>;
|
||||
|
||||
/// Called on startup packet receival, allows to process params.
|
||||
///
|
||||
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
|
||||
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
|
||||
/// to override whole init logic in implementations.
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_sm: &FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check auth jwt
|
||||
fn check_auth_jwt(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_jwt_response: &[u8],
|
||||
) -> Result<(), QueryError> {
|
||||
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
|
||||
}
|
||||
}
|
||||
|
||||
/// PostgresBackend protocol state.
|
||||
/// XXX: The order of the constructors matters.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
|
||||
pub enum ProtoState {
|
||||
Initialization,
|
||||
Encrypted,
|
||||
Authentication,
|
||||
Established,
|
||||
Closed,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ProcessMsgResult {
|
||||
Continue,
|
||||
Break,
|
||||
}
|
||||
|
||||
/// Always-writeable sock_split stream.
|
||||
/// May not be readable. See [`PostgresBackend::take_stream_in`]
|
||||
pub enum Stream {
|
||||
Unencrypted(BufReader<tokio::net::TcpStream>),
|
||||
Tls(Box<tokio_rustls::server::TlsStream<BufReader<tokio::net::TcpStream>>>),
|
||||
Broken,
|
||||
}
|
||||
|
||||
impl AsyncWrite for Stream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
Self::Broken => unreachable!(),
|
||||
}
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
|
||||
Self::Broken => unreachable!(),
|
||||
}
|
||||
}
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
Self::Broken => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AsyncRead for Stream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
Self::Broken => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
stream: Stream,
|
||||
|
||||
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
|
||||
// The data between 0 and "current position" as tracked by the bytes::Buf
|
||||
// implementation of BytesMut, have already been written.
|
||||
buf_out: BytesMut,
|
||||
|
||||
pub state: ProtoState,
|
||||
|
||||
auth_type: AuthType,
|
||||
|
||||
peer_addr: SocketAddr,
|
||||
pub tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
}
|
||||
|
||||
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
|
||||
let mut query_string = query_string.to_vec();
|
||||
if let Some(ch) = query_string.last() {
|
||||
if *ch == 0 {
|
||||
query_string.pop();
|
||||
}
|
||||
}
|
||||
query_string
|
||||
}
|
||||
|
||||
// Cast a byte slice to a string slice, dropping null terminator if there's one.
|
||||
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
|
||||
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
|
||||
std::str::from_utf8(without_null).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
impl PostgresBackend {
|
||||
pub fn new(
|
||||
socket: tokio::net::TcpStream,
|
||||
auth_type: AuthType,
|
||||
tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
) -> io::Result<Self> {
|
||||
let peer_addr = socket.peer_addr()?;
|
||||
|
||||
Ok(Self {
|
||||
stream: Stream::Unencrypted(BufReader::new(socket)),
|
||||
buf_out: BytesMut::with_capacity(10 * 1024),
|
||||
state: ProtoState::Initialization,
|
||||
auth_type,
|
||||
tls_config,
|
||||
peer_addr,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_peer_addr(&self) -> &SocketAddr {
|
||||
&self.peer_addr
|
||||
}
|
||||
|
||||
/// Read full message or return None if connection is closed.
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
|
||||
use ProtoState::*;
|
||||
match self.state {
|
||||
Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await,
|
||||
Authentication | Established => FeMessage::read_fut(&mut self.stream).await,
|
||||
Closed => Ok(None),
|
||||
}
|
||||
.map_err(QueryError::from)
|
||||
}
|
||||
|
||||
/// Flush output buffer into the socket.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
while self.buf_out.has_remaining() {
|
||||
let bytes_written = self.stream.write(self.buf_out.chunk()).await?;
|
||||
self.buf_out.advance(bytes_written);
|
||||
}
|
||||
self.buf_out.clear();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer.
|
||||
pub fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
BeMessage::write(&mut self.buf_out, message)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Returns an AsyncWrite implementation that wraps all the data written
|
||||
/// to it in CopyData messages, and writes them to the connection
|
||||
///
|
||||
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
|
||||
pub fn copyout_writer(&mut self) -> CopyDataWriter {
|
||||
CopyDataWriter { pgb: self }
|
||||
}
|
||||
|
||||
/// A polling function that tries to write all the data from 'buf_out' to the
|
||||
/// underlying stream.
|
||||
fn poll_write_buf(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
while self.buf_out.has_remaining() {
|
||||
match ready!(Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk())) {
|
||||
Ok(bytes_written) => self.buf_out.advance(bytes_written),
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
}
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
Pin::new(&mut self.stream).poll_flush(cx)
|
||||
}
|
||||
|
||||
// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
pub async fn run<F, S>(
|
||||
mut self,
|
||||
handler: &mut impl Handler,
|
||||
shutdown_watcher: F,
|
||||
) -> Result<(), QueryError>
|
||||
where
|
||||
F: Fn() -> S,
|
||||
S: Future,
|
||||
{
|
||||
let ret = self.run_message_loop(handler, shutdown_watcher).await;
|
||||
let _ = self.stream.shutdown();
|
||||
ret
|
||||
}
|
||||
|
||||
async fn run_message_loop<F, S>(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
shutdown_watcher: F,
|
||||
) -> Result<(), QueryError>
|
||||
where
|
||||
F: Fn() -> S,
|
||||
S: Future,
|
||||
{
|
||||
trace!("postgres backend to {:?} started", self.peer_addr);
|
||||
|
||||
tokio::select!(
|
||||
biased;
|
||||
|
||||
_ = shutdown_watcher() => {
|
||||
// We were requested to shut down.
|
||||
tracing::info!("shutdown request received during handshake");
|
||||
return Ok(())
|
||||
},
|
||||
|
||||
result = async {
|
||||
while self.state < ProtoState::Established {
|
||||
if let Some(msg) = self.read_message().await? {
|
||||
trace!("got message {msg:?} during handshake");
|
||||
|
||||
match self.process_handshake_message(handler, msg).await? {
|
||||
ProcessMsgResult::Continue => {
|
||||
self.flush().await?;
|
||||
continue;
|
||||
}
|
||||
ProcessMsgResult::Break => {
|
||||
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Ok::<(), QueryError>(())
|
||||
} => {
|
||||
// Handshake complete.
|
||||
result?;
|
||||
}
|
||||
);
|
||||
|
||||
// Authentication completed
|
||||
let mut query_string = Bytes::new();
|
||||
while let Some(msg) = tokio::select!(
|
||||
biased;
|
||||
_ = shutdown_watcher() => {
|
||||
// We were requested to shut down.
|
||||
tracing::info!("shutdown request received in run_message_loop");
|
||||
Ok(None)
|
||||
},
|
||||
msg = self.read_message() => { msg },
|
||||
)? {
|
||||
trace!("got message {:?}", msg);
|
||||
|
||||
let result = self.process_message(handler, msg, &mut query_string).await;
|
||||
self.flush().await?;
|
||||
match result? {
|
||||
ProcessMsgResult::Continue => {
|
||||
self.flush().await?;
|
||||
continue;
|
||||
}
|
||||
ProcessMsgResult::Break => break,
|
||||
}
|
||||
}
|
||||
|
||||
trace!("postgres backend to {:?} exited", self.peer_addr);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_tls(&mut self) -> anyhow::Result<()> {
|
||||
if let Stream::Unencrypted(plain_stream) =
|
||||
std::mem::replace(&mut self.stream, Stream::Broken)
|
||||
{
|
||||
let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap());
|
||||
let tls_stream = acceptor.accept(plain_stream).await?;
|
||||
|
||||
self.stream = Stream::Tls(Box::new(tls_stream));
|
||||
return Ok(());
|
||||
};
|
||||
anyhow::bail!("TLS already started");
|
||||
}
|
||||
|
||||
async fn process_handshake_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
msg: FeMessage,
|
||||
) -> Result<ProcessMsgResult, QueryError> {
|
||||
assert!(self.state < ProtoState::Established);
|
||||
let have_tls = self.tls_config.is_some();
|
||||
match msg {
|
||||
FeMessage::StartupPacket(m) => {
|
||||
trace!("got startup message {m:?}");
|
||||
|
||||
match m {
|
||||
FeStartupPacket::SslRequest => {
|
||||
debug!("SSL requested");
|
||||
|
||||
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
|
||||
if have_tls {
|
||||
self.start_tls().await?;
|
||||
self.state = ProtoState::Encrypted;
|
||||
}
|
||||
}
|
||||
FeStartupPacket::GssEncRequest => {
|
||||
debug!("GSS requested");
|
||||
self.write_message(&BeMessage::EncryptionResponse(false))?;
|
||||
}
|
||||
FeStartupPacket::StartupMessage { .. } => {
|
||||
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
"must connect with TLS",
|
||||
None,
|
||||
))?;
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"client did not connect with TLS"
|
||||
)));
|
||||
}
|
||||
|
||||
// NB: startup() may change self.auth_type -- we are using that in proxy code
|
||||
// to bypass auth for new users.
|
||||
handler.startup(self, &m)?;
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => {
|
||||
self.write_message(&BeMessage::AuthenticationOk)?
|
||||
.write_message(&BeMessage::CLIENT_ENCODING)?
|
||||
// The async python driver requires a valid server_version
|
||||
.write_message(&BeMessage::server_version("14.1"))?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
AuthType::NeonJWT => {
|
||||
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
|
||||
self.state = ProtoState::Authentication;
|
||||
}
|
||||
}
|
||||
}
|
||||
FeStartupPacket::CancelRequest { .. } => {
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FeMessage::PasswordMessage(m) => {
|
||||
trace!("got password message '{:?}'", m);
|
||||
|
||||
assert!(self.state == ProtoState::Authentication);
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => unreachable!(),
|
||||
AuthType::NeonJWT => {
|
||||
let (_, jwt_response) = m.split_last().context("protocol violation")?;
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
self.write_message(&BeMessage::AuthenticationOk)?
|
||||
.write_message(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
|
||||
_ => {
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
}
|
||||
Ok(ProcessMsgResult::Continue)
|
||||
}
|
||||
|
||||
async fn process_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
msg: FeMessage,
|
||||
unnamed_query_string: &mut Bytes,
|
||||
) -> Result<ProcessMsgResult, QueryError> {
|
||||
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
|
||||
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
|
||||
assert!(self.state == ProtoState::Established);
|
||||
|
||||
match msg {
|
||||
FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
|
||||
}
|
||||
|
||||
FeMessage::Query(body) => {
|
||||
// remove null terminator
|
||||
let query_string = cstr_to_str(&body)?;
|
||||
|
||||
trace!("got query {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string).await {
|
||||
log_query_error(query_string, &e);
|
||||
let short_error = short_error(&e);
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
&short_error,
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
|
||||
FeMessage::Parse(m) => {
|
||||
*unnamed_query_string = m.query_string;
|
||||
self.write_message(&BeMessage::ParseComplete)?;
|
||||
}
|
||||
|
||||
FeMessage::Describe(_) => {
|
||||
self.write_message(&BeMessage::ParameterDescription)?
|
||||
.write_message(&BeMessage::NoData)?;
|
||||
}
|
||||
|
||||
FeMessage::Bind(_) => {
|
||||
self.write_message(&BeMessage::BindComplete)?;
|
||||
}
|
||||
|
||||
FeMessage::Close(_) => {
|
||||
self.write_message(&BeMessage::CloseComplete)?;
|
||||
}
|
||||
|
||||
FeMessage::Execute(_) => {
|
||||
let query_string = cstr_to_str(unnamed_query_string)?;
|
||||
trace!("got execute {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string).await {
|
||||
log_query_error(query_string, &e);
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
// NOTE there is no ReadyForQuery message. This handler is used
|
||||
// for basebackup and it uses CopyOut which doesn't require
|
||||
// ReadyForQuery message and backend just switches back to
|
||||
// processing mode after sending CopyDone or ErrorResponse.
|
||||
}
|
||||
|
||||
FeMessage::Sync => {
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
|
||||
FeMessage::Terminate => {
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
|
||||
// We prefer explicit pattern matching to wildcards, because
|
||||
// this helps us spot the places where new variants are missing
|
||||
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unexpected message type: {:?}",
|
||||
msg
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ProcessMsgResult::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
|
||||
/// messages.
|
||||
///
|
||||
|
||||
pub struct CopyDataWriter<'a> {
|
||||
pgb: &'a mut PostgresBackend,
|
||||
}
|
||||
|
||||
impl<'a> AsyncWrite for CopyDataWriter<'a> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// It's not strictly required to flush between each message, but makes it easier
|
||||
// to view in wireshark, and usually the messages that the callers write are
|
||||
// decently-sized anyway.
|
||||
match ready!(this.pgb.poll_write_buf(cx)) {
|
||||
Ok(()) => {}
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
}
|
||||
|
||||
// CopyData
|
||||
// XXX: if the input is large, we should split it into multiple messages.
|
||||
// Not sure what the threshold should be, but the ultimate hard limit is that
|
||||
// the length cannot exceed u32.
|
||||
this.pgb.write_message(&BeMessage::CopyData(buf))?;
|
||||
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
match ready!(this.pgb.poll_write_buf(cx)) {
|
||||
Ok(()) => {}
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
}
|
||||
this.pgb.poll_flush(cx)
|
||||
}
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
match ready!(this.pgb.poll_write_buf(cx)) {
|
||||
Ok(()) => {}
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
}
|
||||
this.pgb.poll_flush(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn short_error(e: &QueryError) -> String {
|
||||
match e {
|
||||
QueryError::Disconnected(connection_error) => connection_error.to_string(),
|
||||
QueryError::Other(e) => format!("{e:#}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn log_query_error(query: &str, e: &QueryError) {
|
||||
match e {
|
||||
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
|
||||
if is_expected_io_error(io_error) {
|
||||
info!("query handler for '{query}' failed with expected io error: {io_error}");
|
||||
} else {
|
||||
error!("query handler for '{query}' failed with io error: {io_error}");
|
||||
}
|
||||
}
|
||||
QueryError::Disconnected(other_connection_error) => {
|
||||
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
|
||||
}
|
||||
QueryError::Other(e) => {
|
||||
error!("query handler for '{query}' failed: {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
206
libs/utils/src/sock_split.rs
Normal file
206
libs/utils/src/sock_split.rs
Normal file
@@ -0,0 +1,206 @@
|
||||
use std::{
|
||||
io::{self, BufReader, Write},
|
||||
net::{Shutdown, TcpStream},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use rustls::Connection;
|
||||
|
||||
/// Wrapper supporting reads of a shared TcpStream.
|
||||
pub struct ArcTcpRead(Arc<TcpStream>);
|
||||
|
||||
impl io::Read for ArcTcpRead {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
(&*self.0).read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for ArcTcpRead {
|
||||
type Target = TcpStream;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.deref()
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper around a TCP Stream supporting buffered reads.
|
||||
pub struct BufStream(BufReader<ArcTcpRead>);
|
||||
|
||||
impl io::Read for BufStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.0.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for BufStream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.get_ref().write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.get_ref().flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl BufStream {
|
||||
/// Unwrap into the internal BufReader.
|
||||
fn into_reader(self) -> BufReader<ArcTcpRead> {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying TcpStream.
|
||||
fn get_ref(&self) -> &TcpStream {
|
||||
&self.0.get_ref().0
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ReadStream {
|
||||
Tcp(BufReader<ArcTcpRead>),
|
||||
Tls(rustls_split::ReadHalf),
|
||||
}
|
||||
|
||||
impl io::Read for ReadStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(reader) => reader.read(buf),
|
||||
Self::Tls(read_half) => read_half.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadStream {
|
||||
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.get_ref().shutdown(how),
|
||||
Self::Tls(write_half) => write_half.shutdown(how),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum WriteStream {
|
||||
Tcp(Arc<TcpStream>),
|
||||
Tls(rustls_split::WriteHalf),
|
||||
}
|
||||
|
||||
impl WriteStream {
|
||||
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.shutdown(how),
|
||||
Self::Tls(write_half) => write_half.shutdown(how),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for WriteStream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.as_ref().write(buf),
|
||||
Self::Tls(write_half) => write_half.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.as_ref().flush(),
|
||||
Self::Tls(write_half) => write_half.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type TlsStream<T> = rustls::StreamOwned<rustls::ServerConnection, T>;
|
||||
|
||||
pub enum BidiStream {
|
||||
Tcp(BufStream),
|
||||
/// This variant is boxed, because [`rustls::ServerConnection`] is quite larger than [`BufStream`].
|
||||
Tls(Box<TlsStream<BufStream>>),
|
||||
}
|
||||
|
||||
impl BidiStream {
|
||||
pub fn from_tcp(stream: TcpStream) -> Self {
|
||||
Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream)))))
|
||||
}
|
||||
|
||||
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.get_ref().shutdown(how),
|
||||
Self::Tls(tls_boxed) => {
|
||||
if how == Shutdown::Read {
|
||||
tls_boxed.sock.get_ref().shutdown(how)
|
||||
} else {
|
||||
tls_boxed.conn.send_close_notify();
|
||||
let res = tls_boxed.flush();
|
||||
tls_boxed.sock.get_ref().shutdown(how)?;
|
||||
res
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Split the bi-directional stream into two owned read and write halves.
|
||||
pub fn split(self) -> (ReadStream, WriteStream) {
|
||||
match self {
|
||||
Self::Tcp(stream) => {
|
||||
let reader = stream.into_reader();
|
||||
let stream: Arc<TcpStream> = reader.get_ref().0.clone();
|
||||
|
||||
(ReadStream::Tcp(reader), WriteStream::Tcp(stream))
|
||||
}
|
||||
Self::Tls(tls_boxed) => {
|
||||
let reader = tls_boxed.sock.into_reader();
|
||||
let buffer_data = reader.buffer().to_owned();
|
||||
let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192);
|
||||
let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192);
|
||||
|
||||
// TODO would be nice to avoid the Arc here
|
||||
let socket = Arc::try_unwrap(reader.into_inner().0).unwrap();
|
||||
|
||||
let (read_half, write_half) = rustls_split::split(
|
||||
socket,
|
||||
Connection::Server(tls_boxed.conn),
|
||||
read_buf_cfg,
|
||||
write_buf_cfg,
|
||||
);
|
||||
(ReadStream::Tls(read_half), WriteStream::Tls(write_half))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_tls(self, mut conn: rustls::ServerConnection) -> io::Result<Self> {
|
||||
match self {
|
||||
Self::Tcp(mut stream) => {
|
||||
conn.complete_io(&mut stream)?;
|
||||
assert!(!conn.is_handshaking());
|
||||
Ok(Self::Tls(Box::new(TlsStream::new(conn, stream))))
|
||||
}
|
||||
Self::Tls { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"TLS is already started on this stream",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Read for BidiStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.read(buf),
|
||||
Self::Tls(tls_boxed) => tls_boxed.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for BidiStream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.write(buf),
|
||||
Self::Tls(tls_boxed) => tls_boxed.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.flush(),
|
||||
Self::Tls(tls_boxed) => tls_boxed.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
238
libs/utils/tests/ssl_test.rs
Normal file
238
libs/utils/tests/ssl_test.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io::{Cursor, Read, Write},
|
||||
net::{TcpListener, TcpStream},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
use utils::{
|
||||
postgres_backend::{AuthType, Handler, PostgresBackend},
|
||||
postgres_backend_async::QueryError,
|
||||
};
|
||||
|
||||
fn make_tcp_pair() -> (TcpStream, TcpStream) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let client_stream = TcpStream::connect(addr).unwrap();
|
||||
let (server_stream, _) = listener.accept().unwrap();
|
||||
(server_stream, client_stream)
|
||||
}
|
||||
|
||||
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
|
||||
let mut cursor = Cursor::new(include_bytes!("key.pem"));
|
||||
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
|
||||
});
|
||||
|
||||
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
|
||||
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
|
||||
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
|
||||
});
|
||||
|
||||
#[test]
|
||||
// [false-positive](https://github.com/rust-lang/rust-clippy/issues/9274),
|
||||
// we resize the vector so doing some modifications after all
|
||||
#[allow(clippy::read_zero_byte_vec)]
|
||||
fn ssl() {
|
||||
let (mut client_sock, server_sock) = make_tcp_pair();
|
||||
|
||||
const QUERY: &str = "hello world";
|
||||
|
||||
let client_jh = std::thread::spawn(move || {
|
||||
// SSLRequest
|
||||
client_sock.write_u32::<BigEndian>(8).unwrap();
|
||||
client_sock.write_u32::<BigEndian>(80877103).unwrap();
|
||||
|
||||
let ssl_response = client_sock.read_u8().unwrap();
|
||||
assert_eq!(b'S', ssl_response);
|
||||
|
||||
let cfg = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates({
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add(&CERT).unwrap();
|
||||
store
|
||||
})
|
||||
.with_no_client_auth();
|
||||
let client_config = Arc::new(cfg);
|
||||
|
||||
let dns_name = "localhost".try_into().unwrap();
|
||||
let mut conn = rustls::ClientConnection::new(client_config, dns_name).unwrap();
|
||||
|
||||
conn.complete_io(&mut client_sock).unwrap();
|
||||
assert!(!conn.is_handshaking());
|
||||
|
||||
let mut stream = rustls::Stream::new(&mut conn, &mut client_sock);
|
||||
|
||||
// StartupMessage
|
||||
stream.write_u32::<BigEndian>(9).unwrap();
|
||||
stream.write_u32::<BigEndian>(196608).unwrap();
|
||||
stream.write_u8(0).unwrap();
|
||||
stream.flush().unwrap();
|
||||
|
||||
// wait for ReadyForQuery
|
||||
let mut msg_buf = Vec::new();
|
||||
loop {
|
||||
let msg = stream.read_u8().unwrap();
|
||||
let size = stream.read_u32::<BigEndian>().unwrap() - 4;
|
||||
msg_buf.resize(size as usize, 0);
|
||||
stream.read_exact(&mut msg_buf).unwrap();
|
||||
|
||||
if msg == b'Z' {
|
||||
// ReadyForQuery
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Query
|
||||
stream.write_u8(b'Q').unwrap();
|
||||
stream
|
||||
.write_u32::<BigEndian>(4u32 + QUERY.len() as u32)
|
||||
.unwrap();
|
||||
stream.write_all(QUERY.as_ref()).unwrap();
|
||||
stream.flush().unwrap();
|
||||
|
||||
// ReadyForQuery
|
||||
let msg = stream.read_u8().unwrap();
|
||||
assert_eq!(msg, b'Z');
|
||||
});
|
||||
|
||||
struct TestHandler {
|
||||
got_query: bool,
|
||||
}
|
||||
impl Handler for TestHandler {
|
||||
fn process_query(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
self.got_query = query_string == QUERY;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
let mut handler = TestHandler { got_query: false };
|
||||
|
||||
let cfg = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![CERT.clone()], KEY.clone())
|
||||
.unwrap();
|
||||
let tls_config = Some(Arc::new(cfg));
|
||||
|
||||
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
|
||||
pgb.run(&mut handler).unwrap();
|
||||
assert!(handler.got_query);
|
||||
|
||||
client_jh.join().unwrap();
|
||||
|
||||
// TODO consider shutdown behavior
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_ssl() {
|
||||
let (mut client_sock, server_sock) = make_tcp_pair();
|
||||
|
||||
let client_jh = std::thread::spawn(move || {
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
// SSLRequest
|
||||
buf.put_u32(8);
|
||||
buf.put_u32(80877103);
|
||||
client_sock.write_all(&buf).unwrap();
|
||||
buf.clear();
|
||||
|
||||
let ssl_response = client_sock.read_u8().unwrap();
|
||||
assert_eq!(b'N', ssl_response);
|
||||
});
|
||||
|
||||
struct TestHandler;
|
||||
|
||||
impl Handler for TestHandler {
|
||||
fn process_query(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
|
||||
let mut handler = TestHandler;
|
||||
|
||||
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, None, true).unwrap();
|
||||
pgb.run(&mut handler).unwrap();
|
||||
|
||||
client_jh.join().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_forces_ssl() {
|
||||
let (mut client_sock, server_sock) = make_tcp_pair();
|
||||
|
||||
let client_jh = std::thread::spawn(move || {
|
||||
// StartupMessage
|
||||
client_sock.write_u32::<BigEndian>(9).unwrap();
|
||||
client_sock.write_u32::<BigEndian>(196608).unwrap();
|
||||
client_sock.write_u8(0).unwrap();
|
||||
client_sock.flush().unwrap();
|
||||
|
||||
// ErrorResponse
|
||||
assert_eq!(client_sock.read_u8().unwrap(), b'E');
|
||||
let len = client_sock.read_u32::<BigEndian>().unwrap() - 4;
|
||||
|
||||
let mut body = vec![0; len as usize];
|
||||
client_sock.read_exact(&mut body).unwrap();
|
||||
let mut body = Bytes::from(body);
|
||||
|
||||
let mut errors = HashMap::new();
|
||||
loop {
|
||||
let field_type = body.get_u8();
|
||||
if field_type == 0u8 {
|
||||
break;
|
||||
}
|
||||
|
||||
let end_idx = body.iter().position(|&b| b == 0u8).unwrap();
|
||||
let mut value = body.split_to(end_idx + 1);
|
||||
assert_eq!(value[end_idx], 0u8);
|
||||
value.truncate(end_idx);
|
||||
let old = errors.insert(field_type, value);
|
||||
assert!(old.is_none());
|
||||
}
|
||||
|
||||
assert!(!body.has_remaining());
|
||||
|
||||
assert_eq!("must connect with TLS", errors.get(&b'M').unwrap());
|
||||
|
||||
// TODO read failure
|
||||
});
|
||||
|
||||
struct TestHandler;
|
||||
impl Handler for TestHandler {
|
||||
fn process_query(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
let mut handler = TestHandler;
|
||||
|
||||
let cfg = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![CERT.clone()], KEY.clone())
|
||||
.unwrap();
|
||||
let tls_config = Some(Arc::new(cfg));
|
||||
|
||||
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
|
||||
let res = pgb.run(&mut handler).unwrap_err();
|
||||
assert_eq!("client did not connect with TLS", format!("{}", res));
|
||||
|
||||
client_jh.join().unwrap();
|
||||
|
||||
// TODO consider shutdown behavior
|
||||
}
|
||||
4
libs/walproposer/.gitignore
vendored
Normal file
4
libs/walproposer/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
*.a
|
||||
*.o
|
||||
*.tmp
|
||||
pgdata
|
||||
39
libs/walproposer/Cargo.toml
Normal file
39
libs/walproposer/Cargo.toml
Normal file
@@ -0,0 +1,39 @@
|
||||
[package]
|
||||
name = "walproposer"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
atty.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
bytes.workspace = true
|
||||
byteorder.workspace = true
|
||||
anyhow.workspace = true
|
||||
crc32c.workspace = true
|
||||
hex.workspace = true
|
||||
once_cell.workspace = true
|
||||
log.workspace = true
|
||||
libc.workspace = true
|
||||
memoffset.workspace = true
|
||||
thiserror.workspace = true
|
||||
tracing.workspace = true
|
||||
tracing-subscriber = { workspace = true, features = ["json"] }
|
||||
serde.workspace = true
|
||||
scopeguard.workspace = true
|
||||
utils.workspace = true
|
||||
safekeeper.workspace = true
|
||||
postgres_ffi.workspace = true
|
||||
hyper.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
env_logger.workspace = true
|
||||
postgres.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
anyhow.workspace = true
|
||||
bindgen.workspace = true
|
||||
cbindgen = "0.24.0"
|
||||
16
libs/walproposer/README.md
Normal file
16
libs/walproposer/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# walproposer Rust module
|
||||
|
||||
## Rust -> C
|
||||
|
||||
We compile walproposer as a static library and generate Rust bindings for it using `bindgen`.
|
||||
Entrypoint header file is `bindgen_deps.h`.
|
||||
|
||||
## C -> Rust
|
||||
|
||||
We use `cbindgen` to generate C bindings for the Rust code. They are stored in `rust_bindings.h`.
|
||||
|
||||
## How to run the tests
|
||||
|
||||
```
|
||||
export RUSTFLAGS="-C default-linker-libraries"
|
||||
```
|
||||
30
libs/walproposer/bindgen_deps.h
Normal file
30
libs/walproposer/bindgen_deps.h
Normal file
@@ -0,0 +1,30 @@
|
||||
/*
|
||||
* This header file is the input to bindgen. It includes all the
|
||||
* PostgreSQL headers that we need to auto-generate Rust structs
|
||||
* from. If you need to expose a new struct to Rust code, add the
|
||||
* header here, and whitelist the struct in the build.rs file.
|
||||
*/
|
||||
#include "c.h"
|
||||
#include "walproposer.h"
|
||||
|
||||
#include <stdarg.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// Calc a sum of two numbers. Used to test Rust->C function calls.
|
||||
int TestFunc(int a, int b);
|
||||
|
||||
// Run a client for simple simlib test.
|
||||
void RunClientC(uint32_t serverId);
|
||||
|
||||
void WalProposerRust();
|
||||
|
||||
void WalProposerCleanup();
|
||||
|
||||
extern bool debug_enabled;
|
||||
|
||||
// Initialize global variables before calling any Postgres C code.
|
||||
void MyContextInit();
|
||||
|
||||
XLogRecPtr MyInsertRecord();
|
||||
137
libs/walproposer/build.rs
Normal file
137
libs/walproposer/build.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
use std::{env, path::PathBuf, process::Command};
|
||||
use anyhow::{anyhow, Context};
|
||||
use bindgen::CargoCallbacks;
|
||||
|
||||
extern crate bindgen;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
|
||||
|
||||
cbindgen::Builder::new()
|
||||
.with_crate(crate_dir)
|
||||
.with_language(cbindgen::Language::C)
|
||||
.generate()
|
||||
.expect("Unable to generate bindings")
|
||||
.write_to_file("rust_bindings.h");
|
||||
|
||||
// Tell cargo to invalidate the built crate whenever the wrapper changes
|
||||
println!("cargo:rerun-if-changed=bindgen_deps.h,test.c,../../pgxn/neon/walproposer.c,build.sh");
|
||||
println!("cargo:rustc-link-arg=-Wl,--start-group");
|
||||
println!("cargo:rustc-link-arg=-lsim");
|
||||
println!("cargo:rustc-link-arg=-lpgport_srv");
|
||||
println!("cargo:rustc-link-arg=-lpostgres");
|
||||
println!("cargo:rustc-link-arg=-lpgcommon_srv");
|
||||
println!("cargo:rustc-link-arg=-lssl");
|
||||
println!("cargo:rustc-link-arg=-lcrypto");
|
||||
println!("cargo:rustc-link-arg=-lz");
|
||||
println!("cargo:rustc-link-arg=-lpthread");
|
||||
println!("cargo:rustc-link-arg=-lrt");
|
||||
println!("cargo:rustc-link-arg=-ldl");
|
||||
println!("cargo:rustc-link-arg=-lm");
|
||||
println!("cargo:rustc-link-arg=-lwalproposer");
|
||||
println!("cargo:rustc-link-arg=-Wl,--end-group");
|
||||
println!("cargo:rustc-link-search=/home/admin/simulator/libs/walproposer");
|
||||
// disable fPIE
|
||||
println!("cargo:rustc-link-arg=-no-pie");
|
||||
|
||||
// print output of build.sh
|
||||
let output = std::process::Command::new("./build.sh")
|
||||
.output()
|
||||
.expect("could not spawn `clang`");
|
||||
|
||||
println!("stdout: {}", String::from_utf8(output.stdout).unwrap());
|
||||
println!("stderr: {}", String::from_utf8(output.stderr).unwrap());
|
||||
|
||||
if !output.status.success() {
|
||||
// Panic if the command was not successful.
|
||||
panic!("could not compile object file");
|
||||
}
|
||||
|
||||
// // Finding the location of C headers for the Postgres server:
|
||||
// // - if POSTGRES_INSTALL_DIR is set look into it, otherwise look into `<project_root>/pg_install`
|
||||
// // - if there's a `bin/pg_config` file use it for getting include server, otherwise use `<project_root>/pg_install/{PG_MAJORVERSION}/include/postgresql/server`
|
||||
let pg_install_dir = if let Some(postgres_install_dir) = env::var_os("POSTGRES_INSTALL_DIR") {
|
||||
postgres_install_dir.into()
|
||||
} else {
|
||||
PathBuf::from("pg_install")
|
||||
};
|
||||
|
||||
let pg_version = "v15";
|
||||
let mut pg_install_dir_versioned = pg_install_dir.join(pg_version);
|
||||
if pg_install_dir_versioned.is_relative() {
|
||||
let cwd = env::current_dir().context("Failed to get current_dir")?;
|
||||
pg_install_dir_versioned = cwd.join("..").join("..").join(pg_install_dir_versioned);
|
||||
}
|
||||
|
||||
let pg_config_bin = pg_install_dir_versioned
|
||||
.join(pg_version)
|
||||
.join("bin")
|
||||
.join("pg_config");
|
||||
let inc_server_path: String = if pg_config_bin.exists() {
|
||||
let output = Command::new(pg_config_bin)
|
||||
.arg("--includedir-server")
|
||||
.output()
|
||||
.context("failed to execute `pg_config --includedir-server`")?;
|
||||
|
||||
if !output.status.success() {
|
||||
panic!("`pg_config --includedir-server` failed")
|
||||
}
|
||||
|
||||
String::from_utf8(output.stdout)
|
||||
.context("pg_config output is not UTF-8")?
|
||||
.trim_end()
|
||||
.into()
|
||||
} else {
|
||||
let server_path = pg_install_dir_versioned
|
||||
.join("include")
|
||||
.join("postgresql")
|
||||
.join("server")
|
||||
.into_os_string();
|
||||
server_path
|
||||
.into_string()
|
||||
.map_err(|s| anyhow!("Bad postgres server path {s:?}"))?
|
||||
};
|
||||
|
||||
let inc_pgxn_path = "/home/admin/simulator/pgxn/neon";
|
||||
|
||||
// The bindgen::Builder is the main entry point
|
||||
// to bindgen, and lets you build up options for
|
||||
// the resulting bindings.
|
||||
let bindings = bindgen::Builder::default()
|
||||
// The input header we would like to generate
|
||||
// bindings for.
|
||||
.header("bindgen_deps.h")
|
||||
// Tell cargo to invalidate the built crate whenever any of the
|
||||
// included header files changed.
|
||||
.parse_callbacks(Box::new(CargoCallbacks))
|
||||
.allowlist_function("TestFunc")
|
||||
.allowlist_function("RunClientC")
|
||||
.allowlist_function("WalProposerRust")
|
||||
.allowlist_function("MyContextInit")
|
||||
.allowlist_function("WalProposerCleanup")
|
||||
.allowlist_function("MyInsertRecord")
|
||||
.allowlist_var("wal_acceptors_list")
|
||||
.allowlist_var("wal_acceptor_reconnect_timeout")
|
||||
.allowlist_var("wal_acceptor_connection_timeout")
|
||||
.allowlist_var("am_wal_proposer")
|
||||
.allowlist_var("neon_timeline_walproposer")
|
||||
.allowlist_var("neon_tenant_walproposer")
|
||||
.allowlist_var("syncSafekeepers")
|
||||
.allowlist_var("sim_redo_start_lsn")
|
||||
.allowlist_var("debug_enabled")
|
||||
.clang_arg(format!("-I{inc_server_path}"))
|
||||
.clang_arg(format!("-I{inc_pgxn_path}"))
|
||||
.clang_arg(format!("-DSIMLIB"))
|
||||
// Finish the builder and generate the bindings.
|
||||
.generate()
|
||||
// Unwrap the Result and panic on failure.
|
||||
.expect("Unable to generate bindings");
|
||||
|
||||
// Write the bindings to the $OUT_DIR/bindings.rs file.
|
||||
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("bindings.rs");
|
||||
bindings
|
||||
.write_to_file(out_path)
|
||||
.expect("Couldn't write bindings!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
21
libs/walproposer/build.sh
Executable file
21
libs/walproposer/build.sh
Executable file
@@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
cd /home/admin/simulator/libs/walproposer
|
||||
|
||||
# TODO: rewrite to Makefile
|
||||
|
||||
make -C ../.. neon-pg-ext-walproposer
|
||||
make -C ../../pg_install/build/v15/src/backend postgres-lib -s
|
||||
cp ../../pg_install/build/v15/src/backend/libpostgres.a .
|
||||
cp ../../pg_install/build/v15/src/common/libpgcommon_srv.a .
|
||||
cp ../../pg_install/build/v15/src/port/libpgport_srv.a .
|
||||
|
||||
clang -g -c libpqwalproposer.c test.c -ferror-limit=1 -I ../../pg_install/v15/include/postgresql/server -I ../../pgxn/neon
|
||||
rm -rf libsim.a
|
||||
ar rcs libsim.a test.o libpqwalproposer.o
|
||||
|
||||
rm -rf libwalproposer.a
|
||||
|
||||
PGXN_DIR=../../pg_install/build/neon-v15/
|
||||
ar rcs libwalproposer.a $PGXN_DIR/walproposer.o $PGXN_DIR/walproposer_utils.o $PGXN_DIR/neon.o
|
||||
542
libs/walproposer/libpqwalproposer.c
Normal file
542
libs/walproposer/libpqwalproposer.c
Normal file
@@ -0,0 +1,542 @@
|
||||
#include "postgres.h"
|
||||
#include "neon.h"
|
||||
#include "walproposer.h"
|
||||
#include "rust_bindings.h"
|
||||
#include "replication/message.h"
|
||||
#include "access/xlog_internal.h"
|
||||
|
||||
// defined in walproposer.h
|
||||
uint64 sim_redo_start_lsn;
|
||||
XLogRecPtr sim_latest_available_lsn;
|
||||
|
||||
/* Header in walproposer.h -- Wrapper struct to abstract away the libpq connection */
|
||||
struct WalProposerConn
|
||||
{
|
||||
int64_t tcp;
|
||||
};
|
||||
|
||||
/* Helper function */
|
||||
static bool
|
||||
ensure_nonblocking_status(WalProposerConn *conn, bool is_nonblocking)
|
||||
{
|
||||
// walprop_log(LOG, "not implemented");
|
||||
return false;
|
||||
}
|
||||
|
||||
/* Exported function definitions */
|
||||
char *
|
||||
walprop_error_message(WalProposerConn *conn)
|
||||
{
|
||||
// walprop_log(LOG, "not implemented");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
WalProposerConnStatusType
|
||||
walprop_status(WalProposerConn *conn)
|
||||
{
|
||||
// walprop_log(LOG, "not implemented: walprop_status");
|
||||
return WP_CONNECTION_OK;
|
||||
}
|
||||
|
||||
WalProposerConn *
|
||||
walprop_connect_start(char *conninfo)
|
||||
{
|
||||
WalProposerConn *conn;
|
||||
|
||||
walprop_log(LOG, "walprop_connect_start: %s", conninfo);
|
||||
|
||||
const char *connstr_prefix = "host=node port=";
|
||||
Assert(strncmp(conninfo, connstr_prefix, strlen(connstr_prefix)) == 0);
|
||||
|
||||
int nodeId = atoi(conninfo + strlen(connstr_prefix));
|
||||
|
||||
conn = palloc(sizeof(WalProposerConn));
|
||||
conn->tcp = sim_open_tcp(nodeId);
|
||||
return conn;
|
||||
}
|
||||
|
||||
WalProposerConnectPollStatusType
|
||||
walprop_connect_poll(WalProposerConn *conn)
|
||||
{
|
||||
// walprop_log(LOG, "not implemented: walprop_connect_poll");
|
||||
return WP_CONN_POLLING_OK;
|
||||
}
|
||||
|
||||
bool
|
||||
walprop_send_query(WalProposerConn *conn, char *query)
|
||||
{
|
||||
// walprop_log(LOG, "not implemented: walprop_send_query");
|
||||
return true;
|
||||
}
|
||||
|
||||
WalProposerExecStatusType
|
||||
walprop_get_query_result(WalProposerConn *conn)
|
||||
{
|
||||
// walprop_log(LOG, "not implemented: walprop_get_query_result");
|
||||
return WP_EXEC_SUCCESS_COPYBOTH;
|
||||
}
|
||||
|
||||
pgsocket
|
||||
walprop_socket(WalProposerConn *conn)
|
||||
{
|
||||
return (pgsocket) conn->tcp;
|
||||
}
|
||||
|
||||
int
|
||||
walprop_flush(WalProposerConn *conn)
|
||||
{
|
||||
// walprop_log(LOG, "not implemented");
|
||||
return 0;
|
||||
}
|
||||
|
||||
void
|
||||
walprop_finish(WalProposerConn *conn)
|
||||
{
|
||||
// walprop_log(LOG, "walprop_finish not implemented");
|
||||
}
|
||||
|
||||
/*
|
||||
* Receive a message from the safekeeper.
|
||||
*
|
||||
* On success, the data is placed in *buf. It is valid until the next call
|
||||
* to this function.
|
||||
*/
|
||||
PGAsyncReadResult
|
||||
walprop_async_read(WalProposerConn *conn, char **buf, int *amount)
|
||||
{
|
||||
uintptr_t len;
|
||||
char *msg;
|
||||
Event event;
|
||||
|
||||
event = sim_epoll_peek(0);
|
||||
if (event.tcp != conn->tcp || event.tag != Message || event.any_message != Bytes)
|
||||
return PG_ASYNC_READ_TRY_AGAIN;
|
||||
|
||||
event = sim_epoll_rcv(0);
|
||||
|
||||
// walprop_log(LOG, "walprop_async_read, T: %d, tcp: %d, tag: %d", (int) event.tag, (int) event.tcp, (int) event.any_message);
|
||||
Assert(event.tcp == conn->tcp);
|
||||
Assert(event.tag == Message);
|
||||
Assert(event.any_message == Bytes);
|
||||
|
||||
msg = (char*) sim_msg_get_bytes(&len);
|
||||
*buf = msg;
|
||||
*amount = len;
|
||||
// walprop_log(LOG, "walprop_async_read: %d", (int) len);
|
||||
|
||||
return PG_ASYNC_READ_SUCCESS;
|
||||
}
|
||||
|
||||
PGAsyncWriteResult
|
||||
walprop_async_write(WalProposerConn *conn, void const *buf, size_t size)
|
||||
{
|
||||
// walprop_log(LOG, "walprop_async_write");
|
||||
sim_msg_set_bytes(buf, size);
|
||||
sim_tcp_send(conn->tcp);
|
||||
return PG_ASYNC_WRITE_SUCCESS;
|
||||
}
|
||||
|
||||
/*
|
||||
* This function is very similar to walprop_async_write. For more
|
||||
* information, refer to the comments there.
|
||||
*/
|
||||
bool
|
||||
walprop_blocking_write(WalProposerConn *conn, void const *buf, size_t size)
|
||||
{
|
||||
// walprop_log(LOG, "walprop_blocking_write");
|
||||
sim_msg_set_bytes(buf, size);
|
||||
sim_tcp_send(conn->tcp);
|
||||
return true;
|
||||
}
|
||||
|
||||
void
|
||||
sim_start_replication(XLogRecPtr startptr)
|
||||
{
|
||||
walprop_log(LOG, "sim_start_replication: %X/%X", LSN_FORMAT_ARGS(startptr));
|
||||
sim_latest_available_lsn = startptr;
|
||||
|
||||
for (;;)
|
||||
{
|
||||
XLogRecPtr endptr = sim_latest_available_lsn;
|
||||
|
||||
Assert(startptr <= endptr);
|
||||
if (endptr > startptr)
|
||||
{
|
||||
WalProposerBroadcast(startptr, endptr);
|
||||
startptr = endptr;
|
||||
}
|
||||
|
||||
WalProposerPoll();
|
||||
}
|
||||
}
|
||||
|
||||
#define UsableBytesInPage (XLOG_BLCKSZ - SizeOfXLogShortPHD)
|
||||
|
||||
static int UsableBytesInSegment =
|
||||
(DEFAULT_XLOG_SEG_SIZE / XLOG_BLCKSZ * UsableBytesInPage) -
|
||||
(SizeOfXLogLongPHD - SizeOfXLogShortPHD);
|
||||
|
||||
/*
|
||||
* Converts a "usable byte position" to XLogRecPtr. A usable byte position
|
||||
* is the position starting from the beginning of WAL, excluding all WAL
|
||||
* page headers.
|
||||
*/
|
||||
static XLogRecPtr
|
||||
XLogBytePosToRecPtr(uint64 bytepos)
|
||||
{
|
||||
uint64 fullsegs;
|
||||
uint64 fullpages;
|
||||
uint64 bytesleft;
|
||||
uint32 seg_offset;
|
||||
XLogRecPtr result;
|
||||
|
||||
fullsegs = bytepos / UsableBytesInSegment;
|
||||
bytesleft = bytepos % UsableBytesInSegment;
|
||||
|
||||
if (bytesleft < XLOG_BLCKSZ - SizeOfXLogLongPHD)
|
||||
{
|
||||
/* fits on first page of segment */
|
||||
seg_offset = bytesleft + SizeOfXLogLongPHD;
|
||||
}
|
||||
else
|
||||
{
|
||||
/* account for the first page on segment with long header */
|
||||
seg_offset = XLOG_BLCKSZ;
|
||||
bytesleft -= XLOG_BLCKSZ - SizeOfXLogLongPHD;
|
||||
|
||||
fullpages = bytesleft / UsableBytesInPage;
|
||||
bytesleft = bytesleft % UsableBytesInPage;
|
||||
|
||||
seg_offset += fullpages * XLOG_BLCKSZ + bytesleft + SizeOfXLogShortPHD;
|
||||
}
|
||||
|
||||
XLogSegNoOffsetToRecPtr(fullsegs, seg_offset, wal_segment_size, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert an XLogRecPtr to a "usable byte position".
|
||||
*/
|
||||
static uint64
|
||||
XLogRecPtrToBytePos(XLogRecPtr ptr)
|
||||
{
|
||||
uint64 fullsegs;
|
||||
uint32 fullpages;
|
||||
uint32 offset;
|
||||
uint64 result;
|
||||
|
||||
XLByteToSeg(ptr, fullsegs, wal_segment_size);
|
||||
|
||||
fullpages = (XLogSegmentOffset(ptr, wal_segment_size)) / XLOG_BLCKSZ;
|
||||
offset = ptr % XLOG_BLCKSZ;
|
||||
|
||||
if (fullpages == 0)
|
||||
{
|
||||
result = fullsegs * UsableBytesInSegment;
|
||||
if (offset > 0)
|
||||
{
|
||||
Assert(offset >= SizeOfXLogLongPHD);
|
||||
result += offset - SizeOfXLogLongPHD;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
result = fullsegs * UsableBytesInSegment +
|
||||
(XLOG_BLCKSZ - SizeOfXLogLongPHD) + /* account for first page */
|
||||
(fullpages - 1) * UsableBytesInPage; /* full pages */
|
||||
if (offset > 0)
|
||||
{
|
||||
Assert(offset >= SizeOfXLogShortPHD);
|
||||
result += offset - SizeOfXLogShortPHD;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
#define max_rdatas 16
|
||||
|
||||
void InitMyInsert();
|
||||
static void MyBeginInsert();
|
||||
static void MyRegisterData(char *data, int len);
|
||||
static XLogRecPtr MyFinishInsert(RmgrId rmid, uint8 info, uint8 flags);
|
||||
static void MyCopyXLogRecordToWAL(int write_len, XLogRecData *rdata, XLogRecPtr StartPos, XLogRecPtr EndPos);
|
||||
|
||||
/*
|
||||
* An array of XLogRecData structs, to hold registered data.
|
||||
*/
|
||||
static XLogRecData rdatas[max_rdatas];
|
||||
static int num_rdatas; /* entries currently used */
|
||||
static uint32 mainrdata_len; /* total # of bytes in chain */
|
||||
static XLogRecData hdr_rdt;
|
||||
static char hdr_scratch[16000];
|
||||
static XLogRecPtr CurrBytePos;
|
||||
static XLogRecPtr PrevBytePos;
|
||||
|
||||
void InitMyInsert()
|
||||
{
|
||||
CurrBytePos = sim_redo_start_lsn;
|
||||
PrevBytePos = InvalidXLogRecPtr;
|
||||
sim_latest_available_lsn = sim_redo_start_lsn;
|
||||
}
|
||||
|
||||
static void MyBeginInsert()
|
||||
{
|
||||
num_rdatas = 0;
|
||||
mainrdata_len = 0;
|
||||
}
|
||||
|
||||
static void MyRegisterData(char *data, int len)
|
||||
{
|
||||
XLogRecData *rdata;
|
||||
|
||||
if (num_rdatas >= max_rdatas)
|
||||
walprop_log(ERROR, "too much WAL data");
|
||||
rdata = &rdatas[num_rdatas++];
|
||||
|
||||
rdata->data = data;
|
||||
rdata->len = len;
|
||||
rdata->next = NULL;
|
||||
|
||||
if (num_rdatas > 1) {
|
||||
rdatas[num_rdatas - 2].next = rdata;
|
||||
}
|
||||
|
||||
mainrdata_len += len;
|
||||
}
|
||||
|
||||
static XLogRecPtr
|
||||
MyFinishInsert(RmgrId rmid, uint8 info, uint8 flags)
|
||||
{
|
||||
XLogRecData *rdt;
|
||||
uint32 total_len = 0;
|
||||
int block_id;
|
||||
pg_crc32c rdata_crc;
|
||||
XLogRecord *rechdr;
|
||||
char *scratch = hdr_scratch;
|
||||
int size;
|
||||
XLogRecPtr StartPos;
|
||||
XLogRecPtr EndPos;
|
||||
uint64 startbytepos;
|
||||
uint64 endbytepos;
|
||||
|
||||
/*
|
||||
* Note: this function can be called multiple times for the same record.
|
||||
* All the modifications we do to the rdata chains below must handle that.
|
||||
*/
|
||||
|
||||
/* The record begins with the fixed-size header */
|
||||
rechdr = (XLogRecord *) scratch;
|
||||
scratch += SizeOfXLogRecord;
|
||||
|
||||
hdr_rdt.data = hdr_scratch;
|
||||
|
||||
if (num_rdatas > 0)
|
||||
{
|
||||
hdr_rdt.next = &rdatas[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
hdr_rdt.next = NULL;
|
||||
}
|
||||
|
||||
/* followed by main data, if any */
|
||||
if (mainrdata_len > 0)
|
||||
{
|
||||
if (mainrdata_len > 255)
|
||||
{
|
||||
*(scratch++) = (char) XLR_BLOCK_ID_DATA_LONG;
|
||||
memcpy(scratch, &mainrdata_len, sizeof(uint32));
|
||||
scratch += sizeof(uint32);
|
||||
}
|
||||
else
|
||||
{
|
||||
*(scratch++) = (char) XLR_BLOCK_ID_DATA_SHORT;
|
||||
*(scratch++) = (uint8) mainrdata_len;
|
||||
}
|
||||
total_len += mainrdata_len;
|
||||
}
|
||||
|
||||
hdr_rdt.len = (scratch - hdr_scratch);
|
||||
total_len += hdr_rdt.len;
|
||||
|
||||
/*
|
||||
* Calculate CRC of the data
|
||||
*
|
||||
* Note that the record header isn't added into the CRC initially since we
|
||||
* don't know the prev-link yet. Thus, the CRC will represent the CRC of
|
||||
* the whole record in the order: rdata, then backup blocks, then record
|
||||
* header.
|
||||
*/
|
||||
INIT_CRC32C(rdata_crc);
|
||||
COMP_CRC32C(rdata_crc, hdr_scratch + SizeOfXLogRecord, hdr_rdt.len - SizeOfXLogRecord);
|
||||
for (size_t i = 0; i < num_rdatas; i++)
|
||||
{
|
||||
rdt = &rdatas[i];
|
||||
COMP_CRC32C(rdata_crc, rdt->data, rdt->len);
|
||||
}
|
||||
|
||||
/*
|
||||
* Fill in the fields in the record header. Prev-link is filled in later,
|
||||
* once we know where in the WAL the record will be inserted. The CRC does
|
||||
* not include the record header yet.
|
||||
*/
|
||||
rechdr->xl_xid = 0;
|
||||
rechdr->xl_tot_len = total_len;
|
||||
rechdr->xl_info = info;
|
||||
rechdr->xl_rmid = rmid;
|
||||
rechdr->xl_prev = InvalidXLogRecPtr;
|
||||
rechdr->xl_crc = rdata_crc;
|
||||
|
||||
size = MAXALIGN(rechdr->xl_tot_len);
|
||||
|
||||
/* All (non xlog-switch) records should contain data. */
|
||||
Assert(size > SizeOfXLogRecord);
|
||||
|
||||
startbytepos = XLogRecPtrToBytePos(CurrBytePos);
|
||||
endbytepos = startbytepos + size;
|
||||
|
||||
// Get the position.
|
||||
StartPos = XLogBytePosToRecPtr(startbytepos);
|
||||
EndPos = XLogBytePosToRecPtr(startbytepos + size);
|
||||
rechdr->xl_prev = PrevBytePos;
|
||||
|
||||
Assert(XLogRecPtrToBytePos(StartPos) == startbytepos);
|
||||
Assert(XLogRecPtrToBytePos(EndPos) == endbytepos);
|
||||
|
||||
// Update global pointers.
|
||||
CurrBytePos = EndPos;
|
||||
PrevBytePos = StartPos;
|
||||
|
||||
/*
|
||||
* Now that xl_prev has been filled in, calculate CRC of the record
|
||||
* header.
|
||||
*/
|
||||
rdata_crc = rechdr->xl_crc;
|
||||
COMP_CRC32C(rdata_crc, rechdr, offsetof(XLogRecord, xl_crc));
|
||||
FIN_CRC32C(rdata_crc);
|
||||
rechdr->xl_crc = rdata_crc;
|
||||
|
||||
// Now write it to disk.
|
||||
MyCopyXLogRecordToWAL(rechdr->xl_tot_len, &hdr_rdt, StartPos, EndPos);
|
||||
return EndPos;
|
||||
}
|
||||
|
||||
#define INSERT_FREESPACE(endptr) \
|
||||
(((endptr) % XLOG_BLCKSZ == 0) ? 0 : (XLOG_BLCKSZ - (endptr) % XLOG_BLCKSZ))
|
||||
|
||||
static void
|
||||
MyCopyXLogRecordToWAL(int write_len, XLogRecData *rdata, XLogRecPtr StartPos, XLogRecPtr EndPos)
|
||||
{
|
||||
XLogRecPtr CurrPos;
|
||||
int written;
|
||||
int freespace;
|
||||
|
||||
// Write hdr_rdt and `num_rdatas` other datas.
|
||||
CurrPos = StartPos;
|
||||
freespace = INSERT_FREESPACE(CurrPos);
|
||||
written = 0;
|
||||
|
||||
Assert(freespace >= sizeof(uint32));
|
||||
|
||||
while (rdata != NULL)
|
||||
{
|
||||
char *rdata_data = rdata->data;
|
||||
int rdata_len = rdata->len;
|
||||
|
||||
while (rdata_len >= freespace)
|
||||
{
|
||||
char header_buf[SizeOfXLogLongPHD];
|
||||
XLogPageHeader NewPage = (XLogPageHeader) header_buf;
|
||||
|
||||
Assert(CurrPos % XLOG_BLCKSZ >= SizeOfXLogShortPHD || freespace == 0);
|
||||
XLogWalPropWrite(rdata_data, freespace, CurrPos);
|
||||
rdata_data += freespace;
|
||||
rdata_len -= freespace;
|
||||
written += freespace;
|
||||
CurrPos += freespace;
|
||||
|
||||
// Init new page
|
||||
MemSet(header_buf, 0, SizeOfXLogLongPHD);
|
||||
|
||||
/*
|
||||
* Fill the new page's header
|
||||
*/
|
||||
NewPage->xlp_magic = XLOG_PAGE_MAGIC;
|
||||
|
||||
/* NewPage->xlp_info = 0; */ /* done by memset */
|
||||
NewPage->xlp_tli = 1;
|
||||
NewPage->xlp_pageaddr = CurrPos;
|
||||
|
||||
/* NewPage->xlp_rem_len = 0; */ /* done by memset */
|
||||
NewPage->xlp_info |= XLP_BKP_REMOVABLE;
|
||||
|
||||
/*
|
||||
* If first page of an XLOG segment file, make it a long header.
|
||||
*/
|
||||
if ((XLogSegmentOffset(NewPage->xlp_pageaddr, wal_segment_size)) == 0)
|
||||
{
|
||||
XLogLongPageHeader NewLongPage = (XLogLongPageHeader) NewPage;
|
||||
|
||||
NewLongPage->xlp_sysid = 0;
|
||||
NewLongPage->xlp_seg_size = wal_segment_size;
|
||||
NewLongPage->xlp_xlog_blcksz = XLOG_BLCKSZ;
|
||||
NewPage->xlp_info |= XLP_LONG_HEADER;
|
||||
}
|
||||
|
||||
NewPage->xlp_rem_len = write_len - written;
|
||||
if (NewPage->xlp_rem_len > 0) {
|
||||
NewPage->xlp_info |= XLP_FIRST_IS_CONTRECORD;
|
||||
}
|
||||
|
||||
/* skip over the page header */
|
||||
if (XLogSegmentOffset(CurrPos, wal_segment_size) == 0)
|
||||
{
|
||||
XLogWalPropWrite(header_buf, SizeOfXLogLongPHD, CurrPos);
|
||||
CurrPos += SizeOfXLogLongPHD;
|
||||
}
|
||||
else
|
||||
{
|
||||
XLogWalPropWrite(header_buf, SizeOfXLogShortPHD, CurrPos);
|
||||
CurrPos += SizeOfXLogShortPHD;
|
||||
}
|
||||
freespace = INSERT_FREESPACE(CurrPos);
|
||||
}
|
||||
|
||||
Assert(CurrPos % XLOG_BLCKSZ >= SizeOfXLogShortPHD || rdata_len == 0);
|
||||
XLogWalPropWrite(rdata_data, rdata_len, CurrPos);
|
||||
CurrPos += rdata_len;
|
||||
written += rdata_len;
|
||||
freespace -= rdata_len;
|
||||
|
||||
rdata = rdata->next;
|
||||
}
|
||||
|
||||
Assert(written == write_len);
|
||||
CurrPos = MAXALIGN64(CurrPos);
|
||||
Assert(CurrPos == EndPos);
|
||||
}
|
||||
|
||||
XLogRecPtr MyInsertRecord()
|
||||
{
|
||||
const char *prefix = "prefix";
|
||||
const char *message = "message";
|
||||
size_t size = 7;
|
||||
bool transactional = false;
|
||||
|
||||
xl_logical_message xlrec;
|
||||
|
||||
xlrec.dbId = 0;
|
||||
xlrec.transactional = transactional;
|
||||
/* trailing zero is critical; see logicalmsg_desc */
|
||||
xlrec.prefix_size = strlen(prefix) + 1;
|
||||
xlrec.message_size = size;
|
||||
|
||||
MyBeginInsert();
|
||||
MyRegisterData((char *) &xlrec, SizeOfLogicalMessage);
|
||||
MyRegisterData(unconstify(char *, prefix), xlrec.prefix_size);
|
||||
MyRegisterData(unconstify(char *, message), size);
|
||||
|
||||
return MyFinishInsert(RM_LOGICALMSG_ID, XLOG_LOGICAL_MESSAGE, XLOG_INCLUDE_ORIGIN);
|
||||
}
|
||||
106
libs/walproposer/rust_bindings.h
Normal file
106
libs/walproposer/rust_bindings.h
Normal file
@@ -0,0 +1,106 @@
|
||||
#include <stdarg.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
/**
|
||||
* List of all possible AnyMessage.
|
||||
*/
|
||||
enum AnyMessageTag {
|
||||
None,
|
||||
InternalConnect,
|
||||
Just32,
|
||||
ReplCell,
|
||||
Bytes,
|
||||
LSN,
|
||||
};
|
||||
typedef uint8_t AnyMessageTag;
|
||||
|
||||
/**
|
||||
* List of all possible NodeEvent.
|
||||
*/
|
||||
enum EventTag {
|
||||
Timeout,
|
||||
Accept,
|
||||
Closed,
|
||||
Message,
|
||||
Internal,
|
||||
};
|
||||
typedef uint8_t EventTag;
|
||||
|
||||
/**
|
||||
* Event returned by epoll_recv.
|
||||
*/
|
||||
typedef struct Event {
|
||||
EventTag tag;
|
||||
int64_t tcp;
|
||||
AnyMessageTag any_message;
|
||||
} Event;
|
||||
|
||||
void rust_function(uint32_t a);
|
||||
|
||||
/**
|
||||
* C API for the node os.
|
||||
*/
|
||||
void sim_sleep(uint64_t ms);
|
||||
|
||||
uint64_t sim_random(uint64_t max);
|
||||
|
||||
uint32_t sim_id(void);
|
||||
|
||||
int64_t sim_open_tcp(uint32_t dst);
|
||||
|
||||
int64_t sim_open_tcp_nopoll(uint32_t dst);
|
||||
|
||||
/**
|
||||
* Send MESSAGE_BUF content to the given tcp.
|
||||
*/
|
||||
void sim_tcp_send(int64_t tcp);
|
||||
|
||||
/**
|
||||
* Receive a message from the given tcp. Can be used only with tcp opened with
|
||||
* `sim_open_tcp_nopoll`.
|
||||
*/
|
||||
struct Event sim_tcp_recv(int64_t tcp);
|
||||
|
||||
struct Event sim_epoll_rcv(int64_t timeout);
|
||||
|
||||
struct Event sim_epoll_peek(int64_t timeout);
|
||||
|
||||
int64_t sim_now(void);
|
||||
|
||||
void sim_exit(int32_t code, const uint8_t *msg);
|
||||
|
||||
void sim_set_result(int32_t code, const uint8_t *msg);
|
||||
|
||||
void sim_log_event(const int8_t *msg);
|
||||
|
||||
/**
|
||||
* Get tag of the current message.
|
||||
*/
|
||||
AnyMessageTag sim_msg_tag(void);
|
||||
|
||||
/**
|
||||
* Read AnyMessage::Just32 message.
|
||||
*/
|
||||
void sim_msg_get_just_u32(uint32_t *val);
|
||||
|
||||
/**
|
||||
* Read AnyMessage::LSN message.
|
||||
*/
|
||||
void sim_msg_get_lsn(uint64_t *val);
|
||||
|
||||
/**
|
||||
* Write AnyMessage::ReplCell message.
|
||||
*/
|
||||
void sim_msg_set_repl_cell(uint32_t value, uint32_t client_id, uint32_t seqno);
|
||||
|
||||
/**
|
||||
* Write AnyMessage::Bytes message.
|
||||
*/
|
||||
void sim_msg_set_bytes(const char *bytes, uintptr_t len);
|
||||
|
||||
/**
|
||||
* Read AnyMessage::Bytes message.
|
||||
*/
|
||||
const char *sim_msg_get_bytes(uintptr_t *len);
|
||||
36
libs/walproposer/src/lib.rs
Normal file
36
libs/walproposer/src/lib.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
#![allow(non_upper_case_globals)]
|
||||
#![allow(non_camel_case_types)]
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use safekeeper::simlib::node_os::NodeOs;
|
||||
use tracing::info;
|
||||
|
||||
pub mod bindings {
|
||||
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn rust_function(a: u32) {
|
||||
info!("Hello from Rust!");
|
||||
info!("a: {}", a);
|
||||
}
|
||||
|
||||
pub mod sim;
|
||||
pub mod sim_proto;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod simtest;
|
||||
|
||||
pub fn c_context() -> Option<Box<dyn Fn(NodeOs) + Send + Sync>> {
|
||||
Some(Box::new(|os: NodeOs| {
|
||||
sim::c_attach_node_os(os);
|
||||
unsafe { bindings::MyContextInit(); }
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn enable_debug() {
|
||||
unsafe { bindings::debug_enabled = true; }
|
||||
}
|
||||
240
libs/walproposer/src/sim.rs
Normal file
240
libs/walproposer/src/sim.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
use log::debug;
|
||||
use safekeeper::simlib::{network::TCP, node_os::NodeOs, world::NodeEvent};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
collections::HashMap,
|
||||
ffi::{CStr, CString},
|
||||
};
|
||||
use tracing::trace;
|
||||
|
||||
use crate::sim_proto::{anymessage_tag, AnyMessageTag, Event, EventTag, MESSAGE_BUF};
|
||||
|
||||
thread_local! {
|
||||
static CURRENT_NODE_OS: RefCell<Option<NodeOs>> = RefCell::new(None);
|
||||
static TCP_CACHE: RefCell<HashMap<i64, TCP>> = RefCell::new(HashMap::new());
|
||||
}
|
||||
|
||||
/// Get the current node os.
|
||||
fn os() -> NodeOs {
|
||||
CURRENT_NODE_OS.with(|cell| cell.borrow().clone().expect("no node os set"))
|
||||
}
|
||||
|
||||
fn tcp_save(tcp: TCP) -> i64 {
|
||||
TCP_CACHE.with(|cell| {
|
||||
let mut cache = cell.borrow_mut();
|
||||
let id = tcp.id();
|
||||
cache.insert(id, tcp);
|
||||
id
|
||||
})
|
||||
}
|
||||
|
||||
fn tcp_load(id: i64) -> TCP {
|
||||
TCP_CACHE.with(|cell| {
|
||||
let cache = cell.borrow();
|
||||
cache.get(&id).expect("unknown TCP id").clone()
|
||||
})
|
||||
}
|
||||
|
||||
/// Should be called before calling any of the C functions.
|
||||
pub(crate) fn c_attach_node_os(os: NodeOs) {
|
||||
CURRENT_NODE_OS.with(|cell| {
|
||||
*cell.borrow_mut() = Some(os);
|
||||
});
|
||||
TCP_CACHE.with(|cell| {
|
||||
*cell.borrow_mut() = HashMap::new();
|
||||
});
|
||||
}
|
||||
|
||||
/// C API for the node os.
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn sim_sleep(ms: u64) {
|
||||
os().sleep(ms);
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn sim_random(max: u64) -> u64 {
|
||||
os().random(max)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn sim_id() -> u32 {
|
||||
os().id().into()
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn sim_open_tcp(dst: u32) -> i64 {
|
||||
tcp_save(os().open_tcp(dst.into()))
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn sim_open_tcp_nopoll(dst: u32) -> i64 {
|
||||
tcp_save(os().open_tcp_nopoll(dst.into()))
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
/// Send MESSAGE_BUF content to the given tcp.
|
||||
pub extern "C" fn sim_tcp_send(tcp: i64) {
|
||||
tcp_load(tcp).send(MESSAGE_BUF.with(|cell| cell.borrow().clone()));
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
/// Receive a message from the given tcp. Can be used only with tcp opened with
|
||||
/// `sim_open_tcp_nopoll`.
|
||||
pub extern "C" fn sim_tcp_recv(tcp: i64) -> Event {
|
||||
let event = tcp_load(tcp).recv();
|
||||
match event {
|
||||
NodeEvent::Accept(_) => unreachable!(),
|
||||
NodeEvent::Closed(_) => Event {
|
||||
tag: EventTag::Closed,
|
||||
tcp: 0,
|
||||
any_message: AnyMessageTag::None,
|
||||
},
|
||||
NodeEvent::Internal(_) => unreachable!(),
|
||||
NodeEvent::Message((message, _)) => {
|
||||
// store message in thread local storage, C code should use
|
||||
// sim_msg_* functions to access it.
|
||||
MESSAGE_BUF.with(|cell| {
|
||||
*cell.borrow_mut() = message.clone();
|
||||
});
|
||||
Event {
|
||||
tag: EventTag::Message,
|
||||
tcp: 0,
|
||||
any_message: anymessage_tag(&message),
|
||||
}
|
||||
}
|
||||
NodeEvent::WakeTimeout(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn sim_epoll_rcv(timeout: i64) -> Event {
|
||||
let event = os().epoll_recv(timeout);
|
||||
let event = if let Some(event) = event {
|
||||
event
|
||||
} else {
|
||||
return Event {
|
||||
tag: EventTag::Timeout,
|
||||
tcp: 0,
|
||||
any_message: AnyMessageTag::None,
|
||||
};
|
||||
};
|
||||
|
||||
match event {
|
||||
NodeEvent::Accept(tcp) => Event {
|
||||
tag: EventTag::Accept,
|
||||
tcp: tcp_save(tcp),
|
||||
any_message: AnyMessageTag::None,
|
||||
},
|
||||
NodeEvent::Closed(tcp) => Event {
|
||||
tag: EventTag::Closed,
|
||||
tcp: tcp_save(tcp),
|
||||
any_message: AnyMessageTag::None,
|
||||
},
|
||||
NodeEvent::Message((message, tcp)) => {
|
||||
// store message in thread local storage, C code should use
|
||||
// sim_msg_* functions to access it.
|
||||
MESSAGE_BUF.with(|cell| {
|
||||
*cell.borrow_mut() = message.clone();
|
||||
});
|
||||
Event {
|
||||
tag: EventTag::Message,
|
||||
tcp: tcp_save(tcp),
|
||||
any_message: anymessage_tag(&message),
|
||||
}
|
||||
}
|
||||
NodeEvent::Internal(message) => {
|
||||
// store message in thread local storage, C code should use
|
||||
// sim_msg_* functions to access it.
|
||||
MESSAGE_BUF.with(|cell| {
|
||||
*cell.borrow_mut() = message.clone();
|
||||
});
|
||||
Event {
|
||||
tag: EventTag::Internal,
|
||||
tcp: 0,
|
||||
any_message: anymessage_tag(&message),
|
||||
}
|
||||
}
|
||||
NodeEvent::WakeTimeout(_) => {
|
||||
// can't happen
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn sim_epoll_peek(timeout: i64) -> Event {
|
||||
let event = os().epoll_peek(timeout);
|
||||
let event = if let Some(event) = event {
|
||||
event
|
||||
} else {
|
||||
return Event {
|
||||
tag: EventTag::Timeout,
|
||||
tcp: 0,
|
||||
any_message: AnyMessageTag::None,
|
||||
};
|
||||
};
|
||||
|
||||
match event {
|
||||
NodeEvent::Accept(tcp) => Event {
|
||||
tag: EventTag::Accept,
|
||||
tcp: tcp_save(tcp),
|
||||
any_message: AnyMessageTag::None,
|
||||
},
|
||||
NodeEvent::Closed(tcp) => Event {
|
||||
tag: EventTag::Closed,
|
||||
tcp: tcp_save(tcp),
|
||||
any_message: AnyMessageTag::None,
|
||||
},
|
||||
NodeEvent::Message((message, tcp)) => Event {
|
||||
tag: EventTag::Message,
|
||||
tcp: tcp_save(tcp),
|
||||
any_message: anymessage_tag(&message),
|
||||
},
|
||||
NodeEvent::Internal(message) => Event {
|
||||
tag: EventTag::Internal,
|
||||
tcp: 0,
|
||||
any_message: anymessage_tag(&message),
|
||||
},
|
||||
NodeEvent::WakeTimeout(_) => {
|
||||
// can't happen
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn sim_now() -> i64 {
|
||||
os().now() as i64
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn sim_exit(code: i32, msg: *const u8) {
|
||||
trace!("sim_exit({}, {:?})", code, msg);
|
||||
sim_set_result(code, msg);
|
||||
|
||||
// I tried to make use of pthread_exit, but it doesn't work.
|
||||
// https://github.com/rust-lang/unsafe-code-guidelines/issues/211
|
||||
// unsafe { libc::pthread_exit(std::ptr::null_mut()) };
|
||||
|
||||
// https://doc.rust-lang.org/nomicon/unwinding.html
|
||||
// Everyone on the internet saying this is UB, but it works for me,
|
||||
// so I'm going to use it for now.
|
||||
panic!("sim_exit() called from C code")
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn sim_set_result(code: i32, msg: *const u8) {
|
||||
let msg = unsafe { CStr::from_ptr(msg as *const i8) };
|
||||
let msg = msg.to_string_lossy().into_owned();
|
||||
debug!("sim_set_result({}, {:?})", code, msg);
|
||||
os().set_result(code, msg);
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn sim_log_event(msg: *const i8) {
|
||||
let msg = unsafe { CStr::from_ptr(msg) };
|
||||
let msg = msg.to_string_lossy().into_owned();
|
||||
debug!("sim_log_event({:?})", msg);
|
||||
os().log_event(msg);
|
||||
}
|
||||
114
libs/walproposer/src/sim_proto.rs
Normal file
114
libs/walproposer/src/sim_proto.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use safekeeper::simlib::proto::{AnyMessage, ReplCell};
|
||||
use std::{cell::RefCell, ffi::c_char};
|
||||
|
||||
pub(crate) fn anymessage_tag(msg: &AnyMessage) -> AnyMessageTag {
|
||||
match msg {
|
||||
AnyMessage::None => AnyMessageTag::None,
|
||||
AnyMessage::InternalConnect => AnyMessageTag::InternalConnect,
|
||||
AnyMessage::Just32(_) => AnyMessageTag::Just32,
|
||||
AnyMessage::ReplCell(_) => AnyMessageTag::ReplCell,
|
||||
AnyMessage::Bytes(_) => AnyMessageTag::Bytes,
|
||||
AnyMessage::LSN(_) => AnyMessageTag::LSN,
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
pub static MESSAGE_BUF: RefCell<AnyMessage> = RefCell::new(AnyMessage::None);
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
/// Get tag of the current message.
|
||||
pub extern "C" fn sim_msg_tag() -> AnyMessageTag {
|
||||
MESSAGE_BUF.with(|cell| anymessage_tag(&*cell.borrow()))
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
/// Read AnyMessage::Just32 message.
|
||||
pub extern "C" fn sim_msg_get_just_u32(val: &mut u32) {
|
||||
MESSAGE_BUF.with(|cell| match &*cell.borrow() {
|
||||
AnyMessage::Just32(v) => {
|
||||
*val = *v;
|
||||
}
|
||||
_ => panic!("expected Just32 message"),
|
||||
});
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
/// Read AnyMessage::LSN message.
|
||||
pub extern "C" fn sim_msg_get_lsn(val: &mut u64) {
|
||||
MESSAGE_BUF.with(|cell| match &*cell.borrow() {
|
||||
AnyMessage::LSN(v) => {
|
||||
*val = *v;
|
||||
}
|
||||
_ => panic!("expected LSN message"),
|
||||
});
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
/// Write AnyMessage::ReplCell message.
|
||||
pub extern "C" fn sim_msg_set_repl_cell(value: u32, client_id: u32, seqno: u32) {
|
||||
MESSAGE_BUF.with(|cell| {
|
||||
*cell.borrow_mut() = AnyMessage::ReplCell(ReplCell {
|
||||
value,
|
||||
client_id,
|
||||
seqno,
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
/// Write AnyMessage::Bytes message.
|
||||
pub extern "C" fn sim_msg_set_bytes(bytes: *const c_char, len: usize) {
|
||||
MESSAGE_BUF.with(|cell| {
|
||||
// copy bytes to a Rust Vec
|
||||
let mut v: Vec<u8> = Vec::with_capacity(len);
|
||||
unsafe {
|
||||
v.set_len(len);
|
||||
std::ptr::copy_nonoverlapping(bytes as *const u8, v.as_mut_ptr(), len);
|
||||
}
|
||||
*cell.borrow_mut() = AnyMessage::Bytes(v.into());
|
||||
});
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
/// Read AnyMessage::Bytes message.
|
||||
pub extern "C" fn sim_msg_get_bytes(len: *mut usize) -> *const c_char {
|
||||
MESSAGE_BUF.with(|cell| match &*cell.borrow() {
|
||||
AnyMessage::Bytes(v) => {
|
||||
unsafe {
|
||||
*len = v.len();
|
||||
v.as_ptr() as *const i8
|
||||
}
|
||||
}
|
||||
_ => panic!("expected Bytes message"),
|
||||
})
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
/// Event returned by epoll_recv.
|
||||
pub struct Event {
|
||||
pub tag: EventTag,
|
||||
pub tcp: i64,
|
||||
pub any_message: AnyMessageTag,
|
||||
}
|
||||
|
||||
#[repr(u8)]
|
||||
/// List of all possible NodeEvent.
|
||||
pub enum EventTag {
|
||||
Timeout,
|
||||
Accept,
|
||||
Closed,
|
||||
Message,
|
||||
Internal,
|
||||
}
|
||||
|
||||
#[repr(u8)]
|
||||
/// List of all possible AnyMessage.
|
||||
pub enum AnyMessageTag {
|
||||
None,
|
||||
InternalConnect,
|
||||
Just32,
|
||||
ReplCell,
|
||||
Bytes,
|
||||
LSN,
|
||||
}
|
||||
88
libs/walproposer/src/simtest/disk.rs
Normal file
88
libs/walproposer/src/simtest/disk.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use safekeeper::safekeeper::SafeKeeperState;
|
||||
use safekeeper::simlib::sync::Mutex;
|
||||
use utils::id::TenantTimelineId;
|
||||
|
||||
pub struct Disk {
|
||||
pub timelines: Mutex<HashMap<TenantTimelineId, Arc<TimelineDisk>>>,
|
||||
}
|
||||
|
||||
impl Disk {
|
||||
pub fn new() -> Self {
|
||||
Disk {
|
||||
timelines: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn put_state(&self, ttid: &TenantTimelineId, state: SafeKeeperState) -> Arc<TimelineDisk> {
|
||||
self.timelines
|
||||
.lock()
|
||||
.entry(ttid.clone())
|
||||
.and_modify(|e| {
|
||||
let mut mu = e.state.lock();
|
||||
*mu = state.clone();
|
||||
})
|
||||
.or_insert_with(|| {
|
||||
Arc::new(TimelineDisk {
|
||||
state: Mutex::new(state),
|
||||
wal: Mutex::new(BlockStorage::new()),
|
||||
})
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TimelineDisk {
|
||||
pub state: Mutex<SafeKeeperState>,
|
||||
pub wal: Mutex<BlockStorage>,
|
||||
}
|
||||
|
||||
const BLOCK_SIZE: usize = 8192;
|
||||
|
||||
pub struct BlockStorage {
|
||||
blocks: HashMap<u64, [u8; BLOCK_SIZE]>,
|
||||
}
|
||||
|
||||
impl BlockStorage {
|
||||
pub fn new() -> Self {
|
||||
BlockStorage {
|
||||
blocks: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read(&self, pos: u64, buf: &mut [u8]) {
|
||||
let mut buf_offset = 0;
|
||||
let mut storage_pos = pos;
|
||||
while buf_offset < buf.len() {
|
||||
let block_id = storage_pos / BLOCK_SIZE as u64;
|
||||
let block = self.blocks.get(&block_id).unwrap_or(&[0; BLOCK_SIZE]);
|
||||
let block_offset = storage_pos % BLOCK_SIZE as u64;
|
||||
let block_len = BLOCK_SIZE as u64 - block_offset;
|
||||
let buf_len = buf.len() - buf_offset;
|
||||
let copy_len = std::cmp::min(block_len as usize, buf_len);
|
||||
buf[buf_offset..buf_offset + copy_len]
|
||||
.copy_from_slice(&block[block_offset as usize..block_offset as usize + copy_len]);
|
||||
buf_offset += copy_len;
|
||||
storage_pos += copy_len as u64;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write(&mut self, pos: u64, buf: &[u8]) {
|
||||
let mut buf_offset = 0;
|
||||
let mut storage_pos = pos;
|
||||
while buf_offset < buf.len() {
|
||||
let block_id = storage_pos / BLOCK_SIZE as u64;
|
||||
let block = self.blocks.entry(block_id).or_insert([0; BLOCK_SIZE]);
|
||||
let block_offset = storage_pos % BLOCK_SIZE as u64;
|
||||
let block_len = BLOCK_SIZE as u64 - block_offset;
|
||||
let buf_len = buf.len() - buf_offset;
|
||||
let copy_len = std::cmp::min(block_len as usize, buf_len);
|
||||
block[block_offset as usize..block_offset as usize + copy_len]
|
||||
.copy_from_slice(&buf[buf_offset..buf_offset + copy_len]);
|
||||
buf_offset += copy_len;
|
||||
storage_pos += copy_len as u64
|
||||
}
|
||||
}
|
||||
}
|
||||
61
libs/walproposer/src/simtest/log.rs
Normal file
61
libs/walproposer/src/simtest/log.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use std::{sync::Arc, fmt};
|
||||
|
||||
use safekeeper::simlib::{world::World, sync::Mutex};
|
||||
use tracing_subscriber::fmt::{time::FormatTime, format::Writer};
|
||||
use utils::logging;
|
||||
|
||||
use crate::bindings;
|
||||
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SimClock {
|
||||
world_ptr: Arc<Mutex<Option<Arc<World>>>>,
|
||||
}
|
||||
|
||||
impl Default for SimClock {
|
||||
fn default() -> Self {
|
||||
SimClock {
|
||||
world_ptr: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SimClock {
|
||||
pub fn set_world(&self, world: Arc<World>) {
|
||||
*self.world_ptr.lock() = Some(world);
|
||||
}
|
||||
}
|
||||
|
||||
impl FormatTime for SimClock {
|
||||
fn format_time(&self, w: &mut Writer<'_>) -> fmt::Result {
|
||||
let world = self.world_ptr.lock().clone();
|
||||
|
||||
if let Some(world) = world {
|
||||
let now = world.now();
|
||||
write!(w, "[{}]", now)
|
||||
} else {
|
||||
write!(w, "[?]")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_logger() -> SimClock {
|
||||
let debug_enabled = unsafe { bindings::debug_enabled };
|
||||
|
||||
let clock = SimClock::default();
|
||||
let base_logger = tracing_subscriber::fmt()
|
||||
.with_target(false)
|
||||
.with_timer(clock.clone())
|
||||
.with_ansi(true)
|
||||
.with_max_level(match debug_enabled {
|
||||
true => tracing::Level::DEBUG,
|
||||
false => tracing::Level::INFO,
|
||||
})
|
||||
.with_writer(std::io::stdout);
|
||||
base_logger.init();
|
||||
|
||||
// logging::replace_panic_hook_with_tracing_panic_hook().forget();
|
||||
std::panic::set_hook(Box::new(|_| {}));
|
||||
|
||||
clock
|
||||
}
|
||||
11
libs/walproposer/src/simtest/mod.rs
Normal file
11
libs/walproposer/src/simtest/mod.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
#[cfg(test)]
|
||||
pub mod simple_client;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod wp_sk;
|
||||
|
||||
pub mod disk;
|
||||
pub mod safekeeper;
|
||||
pub mod storage;
|
||||
pub mod log;
|
||||
pub mod util;
|
||||
372
libs/walproposer/src/simtest/safekeeper.rs
Normal file
372
libs/walproposer/src/simtest/safekeeper.rs
Normal file
@@ -0,0 +1,372 @@
|
||||
//! Safekeeper communication endpoint to WAL proposer (compute node).
|
||||
//! Gets messages from the network, passes them down to consensus module and
|
||||
//! sends replies back.
|
||||
|
||||
use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use hyper::Uri;
|
||||
use log::info;
|
||||
use safekeeper::{
|
||||
safekeeper::{
|
||||
ProposerAcceptorMessage, SafeKeeper, SafeKeeperState, ServerInfo, UNKNOWN_SERVER_VERSION,
|
||||
},
|
||||
simlib::{network::TCP, node_os::NodeOs, proto::AnyMessage, world::NodeEvent},
|
||||
timeline::TimelineError,
|
||||
SafeKeeperConf, wal_storage::Storage,
|
||||
};
|
||||
use tracing::{debug, info_span};
|
||||
use utils::{
|
||||
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
|
||||
lsn::Lsn,
|
||||
};
|
||||
|
||||
use crate::simtest::storage::DiskStateStorage;
|
||||
|
||||
use super::{
|
||||
disk::{Disk, TimelineDisk},
|
||||
storage::DiskWALStorage,
|
||||
};
|
||||
|
||||
struct ConnState {
|
||||
tcp: TCP,
|
||||
|
||||
greeting: bool,
|
||||
ttid: TenantTimelineId,
|
||||
flush_pending: bool,
|
||||
}
|
||||
|
||||
struct SharedState {
|
||||
sk: SafeKeeper<DiskStateStorage, DiskWALStorage>,
|
||||
disk: Arc<TimelineDisk>,
|
||||
}
|
||||
|
||||
struct GlobalMap {
|
||||
timelines: HashMap<TenantTimelineId, SharedState>,
|
||||
conf: SafeKeeperConf,
|
||||
disk: Arc<Disk>,
|
||||
}
|
||||
|
||||
impl GlobalMap {
|
||||
fn new(disk: Arc<Disk>, conf: SafeKeeperConf) -> Result<Self> {
|
||||
let mut timelines = HashMap::new();
|
||||
|
||||
for (&ttid, disk) in disk.timelines.lock().iter() {
|
||||
debug!("loading timeline {}", ttid);
|
||||
let state = disk.state.lock().clone();
|
||||
|
||||
if state.server.wal_seg_size == 0 {
|
||||
bail!(TimelineError::UninitializedWalSegSize(ttid));
|
||||
}
|
||||
|
||||
if state.server.pg_version == UNKNOWN_SERVER_VERSION {
|
||||
bail!(TimelineError::UninitialinzedPgVersion(ttid));
|
||||
}
|
||||
|
||||
if state.commit_lsn < state.local_start_lsn {
|
||||
bail!(
|
||||
"commit_lsn {} is higher than local_start_lsn {}",
|
||||
state.commit_lsn,
|
||||
state.local_start_lsn
|
||||
);
|
||||
}
|
||||
|
||||
let control_store = DiskStateStorage::new(disk.clone());
|
||||
let wal_store = DiskWALStorage::new(disk.clone(), &control_store)?;
|
||||
|
||||
let sk = SafeKeeper::new(control_store, wal_store, conf.my_id)?;
|
||||
timelines.insert(
|
||||
ttid.clone(),
|
||||
SharedState {
|
||||
sk,
|
||||
disk: disk.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
timelines,
|
||||
conf,
|
||||
disk,
|
||||
})
|
||||
}
|
||||
|
||||
fn create(&mut self, ttid: TenantTimelineId, server_info: ServerInfo) -> Result<()> {
|
||||
if self.timelines.contains_key(&ttid) {
|
||||
bail!("timeline {} already exists", ttid);
|
||||
}
|
||||
|
||||
debug!("creating new timeline {}", ttid);
|
||||
|
||||
let commit_lsn = Lsn::INVALID;
|
||||
let local_start_lsn = Lsn::INVALID;
|
||||
|
||||
// TODO: load state from in-memory storage
|
||||
let state = SafeKeeperState::new(&ttid, server_info, vec![], commit_lsn, local_start_lsn);
|
||||
|
||||
if state.server.wal_seg_size == 0 {
|
||||
bail!(TimelineError::UninitializedWalSegSize(ttid));
|
||||
}
|
||||
|
||||
if state.server.pg_version == UNKNOWN_SERVER_VERSION {
|
||||
bail!(TimelineError::UninitialinzedPgVersion(ttid));
|
||||
}
|
||||
|
||||
if state.commit_lsn < state.local_start_lsn {
|
||||
bail!(
|
||||
"commit_lsn {} is higher than local_start_lsn {}",
|
||||
state.commit_lsn,
|
||||
state.local_start_lsn
|
||||
);
|
||||
}
|
||||
|
||||
let disk_timeline = self.disk.put_state(&ttid, state);
|
||||
let control_store = DiskStateStorage::new(disk_timeline.clone());
|
||||
let wal_store = DiskWALStorage::new(disk_timeline.clone(), &control_store)?;
|
||||
|
||||
let sk = SafeKeeper::new(control_store, wal_store, self.conf.my_id)?;
|
||||
|
||||
self.timelines.insert(
|
||||
ttid.clone(),
|
||||
SharedState {
|
||||
sk,
|
||||
disk: disk_timeline,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get(&mut self, ttid: &TenantTimelineId) -> &mut SharedState {
|
||||
self.timelines.get_mut(ttid).expect("timeline must exist")
|
||||
}
|
||||
|
||||
fn has_tli(&self, ttid: &TenantTimelineId) -> bool {
|
||||
self.timelines.contains_key(ttid)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_server(os: NodeOs, disk: Arc<Disk>) -> Result<()> {
|
||||
let _enter = info_span!("safekeeper", id = os.id()).entered();
|
||||
debug!("started server");
|
||||
os.log_event("started;safekeeper".to_owned());
|
||||
let conf = SafeKeeperConf {
|
||||
workdir: PathBuf::from("."),
|
||||
my_id: NodeId(os.id() as u64),
|
||||
listen_pg_addr: String::new(),
|
||||
listen_http_addr: String::new(),
|
||||
no_sync: false,
|
||||
broker_endpoint: "/".parse::<Uri>().unwrap(),
|
||||
broker_keepalive_interval: Duration::from_secs(0),
|
||||
heartbeat_timeout: Duration::from_secs(0),
|
||||
remote_storage: None,
|
||||
max_offloader_lag_bytes: 0,
|
||||
backup_runtime_threads: None,
|
||||
wal_backup_enabled: false,
|
||||
auth: None,
|
||||
};
|
||||
|
||||
let mut global = GlobalMap::new(disk, conf.clone())?;
|
||||
let mut conns: HashMap<i64, ConnState> = HashMap::new();
|
||||
|
||||
for (&ttid, shared_state) in global.timelines.iter_mut() {
|
||||
let flush_lsn = shared_state.sk.wal_store.flush_lsn();
|
||||
let commit_lsn = shared_state.sk.state.commit_lsn;
|
||||
os.log_event(format!("tli_loaded;{};{}", flush_lsn.0, commit_lsn.0));
|
||||
}
|
||||
|
||||
let epoll = os.epoll();
|
||||
loop {
|
||||
// waiting for the next message
|
||||
let mut next_event = Some(epoll.recv());
|
||||
|
||||
loop {
|
||||
let event = match next_event {
|
||||
Some(event) => event,
|
||||
None => break,
|
||||
};
|
||||
|
||||
match event {
|
||||
NodeEvent::Accept(tcp) => {
|
||||
conns.insert(
|
||||
tcp.id(),
|
||||
ConnState {
|
||||
tcp,
|
||||
greeting: false,
|
||||
ttid: TenantTimelineId::empty(),
|
||||
flush_pending: false,
|
||||
},
|
||||
);
|
||||
}
|
||||
NodeEvent::Message((msg, tcp)) => {
|
||||
let conn = conns.get_mut(&tcp.id());
|
||||
if let Some(conn) = conn {
|
||||
let res = conn.process_any(msg, &mut global);
|
||||
if res.is_err() {
|
||||
debug!("conn {:?} error: {:#}", tcp, res.unwrap_err());
|
||||
conns.remove(&tcp.id());
|
||||
}
|
||||
} else {
|
||||
debug!("conn {:?} was closed, dropping msg {:?}", tcp, msg);
|
||||
}
|
||||
}
|
||||
NodeEvent::Internal(_) => {}
|
||||
NodeEvent::Closed(_) => {}
|
||||
NodeEvent::WakeTimeout(_) => {}
|
||||
}
|
||||
|
||||
// TODO: make simulator support multiple events per tick
|
||||
next_event = epoll.try_recv();
|
||||
}
|
||||
|
||||
conns.retain(|_, conn| {
|
||||
let res = conn.flush(&mut global);
|
||||
if res.is_err() {
|
||||
debug!("conn {:?} error: {:?}", conn.tcp, res);
|
||||
}
|
||||
res.is_ok()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnState {
|
||||
fn process_any(&mut self, any: AnyMessage, global: &mut GlobalMap) -> Result<()> {
|
||||
if let AnyMessage::Bytes(copy_data) = any {
|
||||
let repl_prefix = b"START_REPLICATION ";
|
||||
if !self.greeting && copy_data.starts_with(repl_prefix) {
|
||||
self.process_start_replication(copy_data.slice(repl_prefix.len()..), global)?;
|
||||
bail!("finished processing START_REPLICATION")
|
||||
}
|
||||
|
||||
let msg = ProposerAcceptorMessage::parse(copy_data)?;
|
||||
debug!("got msg: {:?}", msg);
|
||||
return self.process(msg, global);
|
||||
} else {
|
||||
bail!("unexpected message, expected AnyMessage::Bytes");
|
||||
}
|
||||
}
|
||||
|
||||
fn process_start_replication(
|
||||
&mut self,
|
||||
copy_data: Bytes,
|
||||
global: &mut GlobalMap,
|
||||
) -> Result<()> {
|
||||
// format is "<tenant_id> <timeline_id> <start_lsn> <end_lsn>"
|
||||
let str = String::from_utf8(copy_data.to_vec())?;
|
||||
|
||||
let mut parts = str.split(' ');
|
||||
let tenant_id = parts.next().unwrap().parse::<TenantId>()?;
|
||||
let timeline_id = parts.next().unwrap().parse::<TimelineId>()?;
|
||||
let start_lsn = parts.next().unwrap().parse::<u64>()?;
|
||||
let end_lsn = parts.next().unwrap().parse::<u64>()?;
|
||||
|
||||
let ttid = TenantTimelineId::new(tenant_id, timeline_id);
|
||||
let shared_state = global.get(&ttid);
|
||||
|
||||
// read bytes from start_lsn to end_lsn
|
||||
let mut buf = vec![0; (end_lsn - start_lsn) as usize];
|
||||
shared_state.disk.wal.lock().read(start_lsn, &mut buf);
|
||||
|
||||
// send bytes to the client
|
||||
self.tcp.send(AnyMessage::Bytes(Bytes::from(buf)));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn init_timeline(
|
||||
&mut self,
|
||||
ttid: TenantTimelineId,
|
||||
server_info: ServerInfo,
|
||||
global: &mut GlobalMap,
|
||||
) -> Result<()> {
|
||||
self.ttid = ttid;
|
||||
if global.has_tli(&ttid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
global.create(ttid, server_info)
|
||||
}
|
||||
|
||||
fn process(&mut self, msg: ProposerAcceptorMessage, global: &mut GlobalMap) -> Result<()> {
|
||||
if !self.greeting {
|
||||
self.greeting = true;
|
||||
|
||||
match msg {
|
||||
ProposerAcceptorMessage::Greeting(ref greeting) => {
|
||||
debug!(
|
||||
"start handshake with walproposer {:?}",
|
||||
self.tcp,
|
||||
);
|
||||
let server_info = ServerInfo {
|
||||
pg_version: greeting.pg_version,
|
||||
system_id: greeting.system_id,
|
||||
wal_seg_size: greeting.wal_seg_size,
|
||||
};
|
||||
let ttid = TenantTimelineId::new(greeting.tenant_id, greeting.timeline_id);
|
||||
self.init_timeline(ttid, server_info, global)?
|
||||
}
|
||||
_ => {
|
||||
bail!("unexpected message {msg:?} instead of greeting");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let tli = global.get(&self.ttid);
|
||||
|
||||
match msg {
|
||||
ProposerAcceptorMessage::AppendRequest(append_request) => {
|
||||
self.flush_pending = true;
|
||||
self.process_sk_msg(
|
||||
tli,
|
||||
&ProposerAcceptorMessage::NoFlushAppendRequest(append_request),
|
||||
)?;
|
||||
}
|
||||
other => {
|
||||
self.process_sk_msg(tli, &other)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Process FlushWAL if needed.
|
||||
// TODO: add extra flushes, to verify that extra flushes don't break anything
|
||||
fn flush(&mut self, global: &mut GlobalMap) -> Result<()> {
|
||||
if !self.flush_pending {
|
||||
return Ok(());
|
||||
}
|
||||
self.flush_pending = false;
|
||||
let shared_state = global.get(&self.ttid);
|
||||
self.process_sk_msg(shared_state, &ProposerAcceptorMessage::FlushWAL)
|
||||
}
|
||||
|
||||
/// Make safekeeper process a message and send a reply to the TCP
|
||||
fn process_sk_msg(
|
||||
&mut self,
|
||||
shared_state: &mut SharedState,
|
||||
msg: &ProposerAcceptorMessage,
|
||||
) -> Result<()> {
|
||||
let mut reply = shared_state.sk.process_msg(msg)?;
|
||||
if let Some(reply) = &mut reply {
|
||||
// // if this is AppendResponse, fill in proper hot standby feedback and disk consistent lsn
|
||||
// if let AcceptorProposerMessage::AppendResponse(ref mut resp) = reply {
|
||||
// // TODO:
|
||||
// }
|
||||
|
||||
let mut buf = BytesMut::with_capacity(128);
|
||||
reply.serialize(&mut buf)?;
|
||||
|
||||
self.tcp.send(AnyMessage::Bytes(buf.into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ConnState {
|
||||
fn drop(&mut self) {
|
||||
debug!("dropping conn: {:?}", self.tcp);
|
||||
if !std::thread::panicking() {
|
||||
self.tcp.close();
|
||||
}
|
||||
// TODO: clean up non-fsynced WAL
|
||||
}
|
||||
}
|
||||
38
libs/walproposer/src/simtest/simple_client.rs
Normal file
38
libs/walproposer/src/simtest/simple_client.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use safekeeper::{
|
||||
simlib::{
|
||||
network::{Delay, NetworkOptions},
|
||||
world::World,
|
||||
},
|
||||
simtest::{start_simulation, Options},
|
||||
};
|
||||
|
||||
use crate::{bindings::RunClientC, c_context};
|
||||
|
||||
#[test]
|
||||
fn run_rust_c_test() {
|
||||
let delay = Delay {
|
||||
min: 1,
|
||||
max: 5,
|
||||
fail_prob: 0.5,
|
||||
};
|
||||
|
||||
let network = NetworkOptions {
|
||||
keepalive_timeout: Some(50),
|
||||
connect_delay: delay.clone(),
|
||||
send_delay: delay.clone(),
|
||||
};
|
||||
|
||||
let u32_data: [u32; 5] = [1, 2, 3, 4, 5];
|
||||
|
||||
let world = Arc::new(World::new(1337, Arc::new(network), c_context()));
|
||||
start_simulation(Options {
|
||||
world,
|
||||
time_limit: 1_000_000,
|
||||
client_fn: Box::new(move |_, server_id| unsafe {
|
||||
RunClientC(server_id);
|
||||
}),
|
||||
u32_data,
|
||||
});
|
||||
}
|
||||
234
libs/walproposer/src/simtest/storage.rs
Normal file
234
libs/walproposer/src/simtest/storage.rs
Normal file
@@ -0,0 +1,234 @@
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
|
||||
use anyhow::Result;
|
||||
use bytes::{Buf, BytesMut};
|
||||
use log::{debug, info};
|
||||
use postgres_ffi::{waldecoder::WalStreamDecoder, XLogSegNo};
|
||||
use safekeeper::{control_file, safekeeper::SafeKeeperState, wal_storage};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use super::disk::TimelineDisk;
|
||||
|
||||
pub struct DiskStateStorage {
|
||||
persisted_state: SafeKeeperState,
|
||||
disk: Arc<TimelineDisk>,
|
||||
}
|
||||
|
||||
impl DiskStateStorage {
|
||||
pub fn new(disk: Arc<TimelineDisk>) -> Self {
|
||||
let guard = disk.state.lock();
|
||||
let state = guard.clone();
|
||||
drop(guard);
|
||||
DiskStateStorage {
|
||||
persisted_state: state,
|
||||
disk,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl control_file::Storage for DiskStateStorage {
|
||||
fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
|
||||
self.persisted_state = s.clone();
|
||||
*self.disk.state.lock() = s.clone();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for DiskStateStorage {
|
||||
type Target = SafeKeeperState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.persisted_state
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DummyWalStore {
|
||||
lsn: Lsn,
|
||||
}
|
||||
|
||||
impl DummyWalStore {
|
||||
pub fn new() -> Self {
|
||||
DummyWalStore { lsn: Lsn::INVALID }
|
||||
}
|
||||
}
|
||||
|
||||
impl wal_storage::Storage for DummyWalStore {
|
||||
fn flush_lsn(&self) -> Lsn {
|
||||
self.lsn
|
||||
}
|
||||
|
||||
fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
|
||||
self.lsn = startpos + buf.len() as u64;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> {
|
||||
self.lsn = end_pos;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush_wal(&mut self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_up_to(&self) -> Box<dyn Fn(XLogSegNo) -> Result<()>> {
|
||||
Box::new(move |_segno_up_to: XLogSegNo| Ok(()))
|
||||
}
|
||||
|
||||
fn get_metrics(&self) -> safekeeper::metrics::WalStorageMetrics {
|
||||
safekeeper::metrics::WalStorageMetrics::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DiskWALStorage {
|
||||
/// Written to disk, but possibly still in the cache and not fully persisted.
|
||||
/// Also can be ahead of record_lsn, if happen to be in the middle of a WAL record.
|
||||
write_lsn: Lsn,
|
||||
|
||||
/// The LSN of the last WAL record written to disk. Still can be not fully flushed.
|
||||
write_record_lsn: Lsn,
|
||||
|
||||
/// The LSN of the last WAL record flushed to disk.
|
||||
flush_record_lsn: Lsn,
|
||||
|
||||
/// Decoder is required for detecting boundaries of WAL records.
|
||||
decoder: WalStreamDecoder,
|
||||
|
||||
unflushed_bytes: BytesMut,
|
||||
|
||||
disk: Arc<TimelineDisk>,
|
||||
}
|
||||
|
||||
impl DiskWALStorage {
|
||||
pub fn new(disk: Arc<TimelineDisk>, state: &SafeKeeperState) -> Result<Self> {
|
||||
let write_lsn = if state.commit_lsn == Lsn(0) {
|
||||
Lsn(0)
|
||||
} else {
|
||||
Self::find_end_of_wal(disk.clone(), state.commit_lsn)?
|
||||
};
|
||||
|
||||
let flush_lsn = write_lsn;
|
||||
Ok(DiskWALStorage {
|
||||
write_lsn,
|
||||
write_record_lsn: flush_lsn,
|
||||
flush_record_lsn: flush_lsn,
|
||||
decoder: WalStreamDecoder::new(flush_lsn, 15),
|
||||
unflushed_bytes: BytesMut::new(),
|
||||
disk,
|
||||
})
|
||||
}
|
||||
|
||||
fn find_end_of_wal(disk: Arc<TimelineDisk>, start_lsn: Lsn) -> Result<Lsn> {
|
||||
let mut buf = [0; 8192];
|
||||
let mut pos = start_lsn.0;
|
||||
let mut decoder = WalStreamDecoder::new(start_lsn, 15);
|
||||
let mut result = start_lsn;
|
||||
loop {
|
||||
disk.wal.lock().read(pos, &mut buf);
|
||||
pos += buf.len() as u64;
|
||||
decoder.feed_bytes(&buf);
|
||||
|
||||
loop {
|
||||
match decoder.poll_decode() {
|
||||
Ok(Some(record)) => result = record.0,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"find_end_of_wal reached end at {:?}, decode error: {:?}",
|
||||
result, e
|
||||
);
|
||||
return Ok(result);
|
||||
}
|
||||
Ok(None) => break, // need more data
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl wal_storage::Storage for DiskWALStorage {
|
||||
fn flush_lsn(&self) -> Lsn {
|
||||
self.flush_record_lsn
|
||||
}
|
||||
|
||||
fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
|
||||
if self.write_lsn != startpos {
|
||||
panic!("write_wal called with wrong startpos");
|
||||
}
|
||||
|
||||
self.unflushed_bytes.extend_from_slice(buf);
|
||||
self.write_lsn += buf.len() as u64;
|
||||
|
||||
if self.decoder.available() != startpos {
|
||||
info!(
|
||||
"restart decoder from {} to {}",
|
||||
self.decoder.available(),
|
||||
startpos,
|
||||
);
|
||||
self.decoder = WalStreamDecoder::new(startpos, 15);
|
||||
}
|
||||
self.decoder.feed_bytes(buf);
|
||||
loop {
|
||||
match self.decoder.poll_decode()? {
|
||||
None => break, // no full record yet
|
||||
Some((lsn, _rec)) => {
|
||||
self.write_record_lsn = lsn;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> {
|
||||
if self.write_lsn != Lsn(0) && end_pos > self.write_lsn {
|
||||
panic!(
|
||||
"truncate_wal called on non-written WAL, write_lsn={}, end_pos={}",
|
||||
self.write_lsn, end_pos
|
||||
);
|
||||
}
|
||||
|
||||
self.flush_wal()?;
|
||||
|
||||
// write zeroes to disk from end_pos until self.write_lsn
|
||||
let buf = [0; 8192];
|
||||
let mut pos = end_pos.0;
|
||||
while pos < self.write_lsn.0 {
|
||||
self.disk.wal.lock().write(pos, &buf);
|
||||
pos += buf.len() as u64;
|
||||
}
|
||||
|
||||
self.write_lsn = end_pos;
|
||||
self.write_record_lsn = end_pos;
|
||||
self.flush_record_lsn = end_pos;
|
||||
self.unflushed_bytes.clear();
|
||||
self.decoder = WalStreamDecoder::new(end_pos, 15);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush_wal(&mut self) -> Result<()> {
|
||||
if self.flush_record_lsn == self.write_record_lsn {
|
||||
// no need to do extra flush
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let num_bytes = self.write_record_lsn.0 - self.flush_record_lsn.0;
|
||||
|
||||
self.disk.wal.lock().write(
|
||||
self.flush_record_lsn.0,
|
||||
&self.unflushed_bytes[..num_bytes as usize],
|
||||
);
|
||||
self.unflushed_bytes.advance(num_bytes as usize);
|
||||
self.flush_record_lsn = self.write_record_lsn;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_up_to(&self) -> Box<dyn Fn(XLogSegNo) -> Result<()>> {
|
||||
Box::new(move |_segno_up_to: XLogSegNo| Ok(()))
|
||||
}
|
||||
|
||||
fn get_metrics(&self) -> safekeeper::metrics::WalStorageMetrics {
|
||||
safekeeper::metrics::WalStorageMetrics::default()
|
||||
}
|
||||
}
|
||||
610
libs/walproposer/src/simtest/util.rs
Normal file
610
libs/walproposer/src/simtest/util.rs
Normal file
@@ -0,0 +1,610 @@
|
||||
use std::{ffi::CString, path::Path, str::FromStr, sync::Arc, collections::HashMap};
|
||||
|
||||
use rand::{Rng, SeedableRng};
|
||||
use safekeeper::simlib::{
|
||||
network::{Delay, NetworkOptions},
|
||||
proto::AnyMessage,
|
||||
time::EmptyEvent,
|
||||
world::World,
|
||||
world::{Node, NodeEvent, SEvent, NodeId},
|
||||
};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use utils::{id::TenantTimelineId, lsn::Lsn};
|
||||
|
||||
use crate::{
|
||||
bindings::{
|
||||
neon_tenant_walproposer, neon_timeline_walproposer, sim_redo_start_lsn, syncSafekeepers,
|
||||
wal_acceptor_connection_timeout, wal_acceptor_reconnect_timeout, wal_acceptors_list,
|
||||
MyInsertRecord, WalProposerCleanup, WalProposerRust,
|
||||
},
|
||||
c_context,
|
||||
simtest::{
|
||||
log::{init_logger, SimClock},
|
||||
safekeeper::run_server,
|
||||
},
|
||||
};
|
||||
|
||||
use super::disk::Disk;
|
||||
|
||||
pub struct SkNode {
|
||||
pub node: Arc<Node>,
|
||||
pub id: u32,
|
||||
pub disk: Arc<Disk>,
|
||||
}
|
||||
|
||||
impl SkNode {
|
||||
pub fn new(node: Arc<Node>) -> Self {
|
||||
let disk = Arc::new(Disk::new());
|
||||
let res = Self {
|
||||
id: node.id,
|
||||
node,
|
||||
disk,
|
||||
};
|
||||
res.launch();
|
||||
res
|
||||
}
|
||||
|
||||
pub fn launch(&self) {
|
||||
let id = self.id;
|
||||
let disk = self.disk.clone();
|
||||
// start the server thread
|
||||
self.node.launch(move |os| {
|
||||
let res = run_server(os, disk);
|
||||
debug!("server {} finished: {:?}", id, res);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn restart(&self) {
|
||||
self.node.crash_stop();
|
||||
self.launch();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TestConfig {
|
||||
pub network: NetworkOptions,
|
||||
pub timeout: u64,
|
||||
pub clock: Option<SimClock>,
|
||||
}
|
||||
|
||||
impl TestConfig {
|
||||
pub fn new(clock: Option<SimClock>) -> Self {
|
||||
Self {
|
||||
network: NetworkOptions {
|
||||
keepalive_timeout: Some(2000),
|
||||
connect_delay: Delay {
|
||||
min: 1,
|
||||
max: 5,
|
||||
fail_prob: 0.0,
|
||||
},
|
||||
send_delay: Delay {
|
||||
min: 1,
|
||||
max: 5,
|
||||
fail_prob: 0.0,
|
||||
},
|
||||
},
|
||||
timeout: 1_000 * 10,
|
||||
clock,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(&self, seed: u64) -> Test {
|
||||
let world = Arc::new(World::new(
|
||||
seed,
|
||||
Arc::new(self.network.clone()),
|
||||
c_context(),
|
||||
));
|
||||
world.register_world();
|
||||
|
||||
if let Some(clock) = &self.clock {
|
||||
clock.set_world(world.clone());
|
||||
}
|
||||
|
||||
let servers = [
|
||||
SkNode::new(world.new_node()),
|
||||
SkNode::new(world.new_node()),
|
||||
SkNode::new(world.new_node()),
|
||||
];
|
||||
|
||||
let server_ids = [servers[0].id, servers[1].id, servers[2].id];
|
||||
|
||||
let safekeepers_guc = server_ids.map(|id| format!("node:{}", id)).join(",");
|
||||
let ttid = TenantTimelineId::generate();
|
||||
|
||||
// wait init for all servers
|
||||
world.await_all();
|
||||
|
||||
// clean up pgdata directory
|
||||
self.init_pgdata();
|
||||
|
||||
Test {
|
||||
world,
|
||||
servers,
|
||||
safekeepers_guc,
|
||||
ttid,
|
||||
timeout: self.timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_pgdata(&self) {
|
||||
let pgdata = Path::new("/home/admin/simulator/libs/walproposer/pgdata");
|
||||
if pgdata.exists() {
|
||||
std::fs::remove_dir_all(pgdata).unwrap();
|
||||
}
|
||||
std::fs::create_dir(pgdata).unwrap();
|
||||
|
||||
// create empty pg_wal and pg_notify subdirs
|
||||
std::fs::create_dir(pgdata.join("pg_wal")).unwrap();
|
||||
std::fs::create_dir(pgdata.join("pg_notify")).unwrap();
|
||||
|
||||
// write postgresql.conf
|
||||
let mut conf = std::fs::File::create(pgdata.join("postgresql.conf")).unwrap();
|
||||
let content = "
|
||||
wal_log_hints=off
|
||||
hot_standby=on
|
||||
fsync=off
|
||||
wal_level=replica
|
||||
restart_after_crash=off
|
||||
shared_preload_libraries=neon
|
||||
neon.pageserver_connstring=''
|
||||
neon.tenant_id=cc6e67313d57283bad411600fbf5c142
|
||||
neon.timeline_id=de6fa815c1e45aa61491c3d34c4eb33e
|
||||
synchronous_standby_names=walproposer
|
||||
neon.safekeepers='node:1,node:2,node:3'
|
||||
max_connections=100
|
||||
";
|
||||
|
||||
std::io::Write::write_all(&mut conf, content.as_bytes()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Test {
|
||||
pub world: Arc<World>,
|
||||
pub servers: [SkNode; 3],
|
||||
pub safekeepers_guc: String,
|
||||
pub ttid: TenantTimelineId,
|
||||
pub timeout: u64,
|
||||
}
|
||||
|
||||
impl Test {
|
||||
fn launch_sync(&self) -> Arc<Node> {
|
||||
let client_node = self.world.new_node();
|
||||
debug!("sync-safekeepers started at node {}", client_node.id);
|
||||
|
||||
// start the client thread
|
||||
let guc = self.safekeepers_guc.clone();
|
||||
let ttid = self.ttid.clone();
|
||||
client_node.launch(move |_| {
|
||||
let list = CString::new(guc).unwrap();
|
||||
|
||||
unsafe {
|
||||
WalProposerCleanup();
|
||||
|
||||
syncSafekeepers = true;
|
||||
wal_acceptors_list = list.into_raw();
|
||||
wal_acceptor_reconnect_timeout = 1000;
|
||||
wal_acceptor_connection_timeout = 5000;
|
||||
neon_tenant_walproposer =
|
||||
CString::new(ttid.tenant_id.to_string()).unwrap().into_raw();
|
||||
neon_timeline_walproposer = CString::new(ttid.timeline_id.to_string())
|
||||
.unwrap()
|
||||
.into_raw();
|
||||
WalProposerRust();
|
||||
}
|
||||
});
|
||||
|
||||
self.world.await_all();
|
||||
|
||||
client_node
|
||||
}
|
||||
|
||||
pub fn sync_safekeepers(&self) -> anyhow::Result<Lsn> {
|
||||
let client_node = self.launch_sync();
|
||||
|
||||
// poll until exit or timeout
|
||||
let time_limit = self.timeout;
|
||||
while self.world.step() && self.world.now() < time_limit && !client_node.is_finished() {}
|
||||
|
||||
if !client_node.is_finished() {
|
||||
anyhow::bail!("timeout or idle stuck");
|
||||
}
|
||||
|
||||
let res = client_node.result.lock().clone();
|
||||
if res.0 != 0 {
|
||||
anyhow::bail!("non-zero exitcode: {:?}", res);
|
||||
}
|
||||
let lsn = Lsn::from_str(&res.1)?;
|
||||
Ok(lsn)
|
||||
}
|
||||
|
||||
pub fn launch_walproposer(&self, lsn: Lsn) -> WalProposer {
|
||||
let client_node = self.world.new_node();
|
||||
|
||||
let lsn = if lsn.0 == 0 {
|
||||
// usual LSN after basebackup
|
||||
Lsn(21623024)
|
||||
} else {
|
||||
lsn
|
||||
};
|
||||
|
||||
// start the client thread
|
||||
let guc = self.safekeepers_guc.clone();
|
||||
let ttid = self.ttid.clone();
|
||||
client_node.launch(move |_| {
|
||||
let list = CString::new(guc).unwrap();
|
||||
|
||||
unsafe {
|
||||
WalProposerCleanup();
|
||||
|
||||
sim_redo_start_lsn = lsn.0;
|
||||
syncSafekeepers = false;
|
||||
wal_acceptors_list = list.into_raw();
|
||||
wal_acceptor_reconnect_timeout = 1000;
|
||||
wal_acceptor_connection_timeout = 5000;
|
||||
neon_tenant_walproposer =
|
||||
CString::new(ttid.tenant_id.to_string()).unwrap().into_raw();
|
||||
neon_timeline_walproposer = CString::new(ttid.timeline_id.to_string())
|
||||
.unwrap()
|
||||
.into_raw();
|
||||
WalProposerRust();
|
||||
}
|
||||
});
|
||||
|
||||
self.world.await_all();
|
||||
|
||||
WalProposer {
|
||||
node: client_node,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn poll_for_duration(&self, duration: u64) {
|
||||
let time_limit = std::cmp::min(self.world.now() + duration, self.timeout);
|
||||
while self.world.step() && self.world.now() < time_limit {}
|
||||
}
|
||||
|
||||
pub fn run_schedule(&self, schedule: &Schedule) -> anyhow::Result<()> {
|
||||
{
|
||||
let empty_event = Box::new(EmptyEvent);
|
||||
|
||||
let now = self.world.now();
|
||||
for (time, _) in schedule {
|
||||
if *time < now {
|
||||
continue;
|
||||
}
|
||||
self.world.schedule(*time - now, empty_event.clone())
|
||||
}
|
||||
}
|
||||
|
||||
let mut wait_node = self.launch_sync();
|
||||
// fake walproposer
|
||||
let mut wp = WalProposer {
|
||||
node: wait_node.clone(),
|
||||
};
|
||||
let mut sync_in_progress = true;
|
||||
|
||||
let mut skipped_tx = 0;
|
||||
let mut started_tx = 0;
|
||||
|
||||
let mut schedule_ptr = 0;
|
||||
|
||||
loop {
|
||||
if sync_in_progress && wait_node.is_finished() {
|
||||
let res = wait_node.result.lock().clone();
|
||||
if res.0 != 0 {
|
||||
warn!("sync non-zero exitcode: {:?}", res);
|
||||
debug!("restarting walproposer");
|
||||
wait_node = self.launch_sync();
|
||||
continue;
|
||||
}
|
||||
let lsn = Lsn::from_str(&res.1)?;
|
||||
debug!("sync-safekeepers finished at LSN {}", lsn);
|
||||
wp = self.launch_walproposer(lsn);
|
||||
wait_node = wp.node.clone();
|
||||
debug!("walproposer started at node {}", wait_node.id);
|
||||
sync_in_progress = false;
|
||||
}
|
||||
|
||||
let now = self.world.now();
|
||||
while schedule_ptr < schedule.len() && schedule[schedule_ptr].0 <= now {
|
||||
if now != schedule[schedule_ptr].0 {
|
||||
warn!("skipped event {:?} at {}", schedule[schedule_ptr], now);
|
||||
}
|
||||
|
||||
let action = &schedule[schedule_ptr].1;
|
||||
match action {
|
||||
TestAction::WriteTx(size) => {
|
||||
if !sync_in_progress && !wait_node.is_finished() {
|
||||
started_tx += *size;
|
||||
wp.write_tx(*size);
|
||||
debug!("written {} transactions", size);
|
||||
} else {
|
||||
skipped_tx += size;
|
||||
debug!("skipped {} transactions", size);
|
||||
}
|
||||
}
|
||||
TestAction::RestartSafekeeper(id) => {
|
||||
debug!("restarting safekeeper {}", id);
|
||||
self.servers[*id as usize].restart();
|
||||
}
|
||||
TestAction::RestartWalProposer => {
|
||||
debug!("restarting walproposer");
|
||||
wait_node.crash_stop();
|
||||
sync_in_progress = true;
|
||||
wait_node = self.launch_sync();
|
||||
}
|
||||
}
|
||||
schedule_ptr += 1;
|
||||
}
|
||||
|
||||
if schedule_ptr == schedule.len() {
|
||||
break;
|
||||
}
|
||||
let next_event_time = schedule[schedule_ptr].0;
|
||||
|
||||
// poll until the next event
|
||||
if wait_node.is_finished() {
|
||||
while self.world.step() && self.world.now() < next_event_time {}
|
||||
} else {
|
||||
while self.world.step()
|
||||
&& self.world.now() < next_event_time
|
||||
&& !wait_node.is_finished()
|
||||
{}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("finished schedule");
|
||||
debug!("skipped_tx: {}", skipped_tx);
|
||||
debug!("started_tx: {}", started_tx);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WalProposer {
|
||||
pub node: Arc<Node>,
|
||||
}
|
||||
|
||||
impl WalProposer {
|
||||
pub fn write_tx(&mut self, cnt: usize) {
|
||||
self.node
|
||||
.network_chan()
|
||||
.send(NodeEvent::Internal(AnyMessage::Just32(cnt as u32)));
|
||||
}
|
||||
|
||||
pub fn stop(&self) {
|
||||
self.node.crash_stop();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TestAction {
|
||||
WriteTx(usize),
|
||||
RestartSafekeeper(usize),
|
||||
RestartWalProposer,
|
||||
}
|
||||
|
||||
pub type Schedule = Vec<(u64, TestAction)>;
|
||||
|
||||
pub fn generate_schedule(seed: u64) -> Schedule {
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
let mut schedule = Vec::new();
|
||||
let mut time = 0;
|
||||
|
||||
let cnt = rng.gen_range(1..100);
|
||||
|
||||
for _ in 0..cnt {
|
||||
time += rng.gen_range(0..500);
|
||||
let action = match rng.gen_range(0..3) {
|
||||
0 => TestAction::WriteTx(rng.gen_range(1..10)),
|
||||
1 => TestAction::RestartSafekeeper(rng.gen_range(0..3)),
|
||||
2 => TestAction::RestartWalProposer,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
schedule.push((time, action));
|
||||
}
|
||||
|
||||
schedule
|
||||
}
|
||||
|
||||
pub fn generate_network_opts(seed: u64) -> NetworkOptions {
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
|
||||
let timeout = rng.gen_range(100..2000);
|
||||
let max_delay = rng.gen_range(1..2*timeout);
|
||||
let min_delay = rng.gen_range(1..=max_delay);
|
||||
|
||||
let max_fail_prob = rng.gen_range(0.0..0.9);
|
||||
let connect_fail_prob = rng.gen_range(0.0..max_fail_prob);
|
||||
let send_fail_prob = rng.gen_range(0.0..connect_fail_prob);
|
||||
|
||||
NetworkOptions {
|
||||
keepalive_timeout: Some(timeout),
|
||||
connect_delay: Delay {
|
||||
min: min_delay,
|
||||
max: max_delay,
|
||||
fail_prob: connect_fail_prob,
|
||||
},
|
||||
send_delay: Delay {
|
||||
min: min_delay,
|
||||
max: max_delay,
|
||||
fail_prob: send_fail_prob,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug,Clone,PartialEq,Eq)]
|
||||
enum NodeKind {
|
||||
Unknown,
|
||||
Safekeeper,
|
||||
WalProposer,
|
||||
}
|
||||
|
||||
impl Default for NodeKind {
|
||||
fn default() -> Self {
|
||||
Self::Unknown
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct NodeInfo {
|
||||
kind: NodeKind,
|
||||
|
||||
// walproposer
|
||||
is_sync: bool,
|
||||
term: u64,
|
||||
epoch_lsn: u64,
|
||||
|
||||
// safekeeper
|
||||
commit_lsn: u64,
|
||||
flush_lsn: u64,
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
fn init_kind(&mut self, kind: NodeKind) {
|
||||
if self.kind == NodeKind::Unknown {
|
||||
self.kind = kind;
|
||||
} else {
|
||||
assert!(self.kind == kind);
|
||||
}
|
||||
}
|
||||
|
||||
fn started(&mut self, data: &str) {
|
||||
let mut parts = data.split(';');
|
||||
assert!(parts.next().unwrap() == "started");
|
||||
match parts.next().unwrap() {
|
||||
"safekeeper" => {
|
||||
self.init_kind(NodeKind::Safekeeper);
|
||||
}
|
||||
"walproposer" => {
|
||||
self.init_kind(NodeKind::WalProposer);
|
||||
let is_sync: u8 = parts.next().unwrap().parse().unwrap();
|
||||
self.is_sync = is_sync != 0;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug,Default)]
|
||||
struct GlobalState {
|
||||
nodes: Vec<NodeInfo>,
|
||||
commit_lsn: u64,
|
||||
write_lsn: u64,
|
||||
max_write_lsn: u64,
|
||||
|
||||
written_wal: u64,
|
||||
written_records: u64,
|
||||
}
|
||||
|
||||
impl GlobalState {
|
||||
fn new() -> Self {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn get(&mut self, id: u32) -> &mut NodeInfo {
|
||||
let id = id as usize;
|
||||
if id >= self.nodes.len() {
|
||||
self.nodes.resize(id + 1, NodeInfo::default());
|
||||
}
|
||||
&mut self.nodes[id]
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate_events(events: Vec<SEvent>) {
|
||||
const INITDB_LSN: u64 = 21623024;
|
||||
|
||||
let hook = std::panic::take_hook();
|
||||
scopeguard::defer_on_success! {
|
||||
std::panic::set_hook(hook);
|
||||
};
|
||||
|
||||
let mut state = GlobalState::new();
|
||||
state.max_write_lsn = INITDB_LSN;
|
||||
|
||||
for event in events {
|
||||
debug!("{:?}", event);
|
||||
|
||||
let node = state.get(event.node);
|
||||
if event.data.starts_with("started;") {
|
||||
node.started(&event.data);
|
||||
continue;
|
||||
}
|
||||
assert!(node.kind != NodeKind::Unknown);
|
||||
|
||||
// drop reference to unlock state
|
||||
let mut node = node.clone();
|
||||
|
||||
let mut parts = event.data.split(';');
|
||||
match node.kind {
|
||||
NodeKind::Safekeeper => {
|
||||
match parts.next().unwrap() {
|
||||
"tli_loaded" => {
|
||||
let flush_lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
let commit_lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
node.flush_lsn = flush_lsn;
|
||||
node.commit_lsn = commit_lsn;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
NodeKind::WalProposer => {
|
||||
match parts.next().unwrap() {
|
||||
"prop_elected" => {
|
||||
let prop_lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
let prop_term: u64 = parts.next().unwrap().parse().unwrap();
|
||||
let prev_lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
let prev_term: u64 = parts.next().unwrap().parse().unwrap();
|
||||
|
||||
assert!(prop_lsn >= prev_lsn);
|
||||
assert!(prop_term >= prev_term);
|
||||
|
||||
assert!(prop_lsn >= state.commit_lsn);
|
||||
|
||||
if prop_lsn > state.write_lsn {
|
||||
assert!(prop_lsn <= state.max_write_lsn);
|
||||
debug!("moving write_lsn up from {} to {}", state.write_lsn, prop_lsn);
|
||||
state.write_lsn = prop_lsn;
|
||||
}
|
||||
if prop_lsn < state.write_lsn {
|
||||
debug!("moving write_lsn down from {} to {}", state.write_lsn, prop_lsn);
|
||||
state.write_lsn = prop_lsn;
|
||||
}
|
||||
|
||||
node.epoch_lsn = prop_lsn;
|
||||
node.term = prop_term;
|
||||
}
|
||||
"write_wal" => {
|
||||
assert!(!node.is_sync);
|
||||
let start_lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
let end_lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
let cnt: u64 = parts.next().unwrap().parse().unwrap();
|
||||
|
||||
let size = end_lsn - start_lsn;
|
||||
state.written_wal += size;
|
||||
state.written_records += cnt;
|
||||
|
||||
// TODO: If we allow writing WAL before winning the election
|
||||
|
||||
assert!(start_lsn >= state.commit_lsn);
|
||||
assert!(end_lsn >= start_lsn);
|
||||
assert!(start_lsn == state.write_lsn);
|
||||
state.write_lsn = end_lsn;
|
||||
|
||||
if end_lsn > state.max_write_lsn {
|
||||
state.max_write_lsn = end_lsn;
|
||||
}
|
||||
}
|
||||
"commit_lsn" => {
|
||||
let lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
assert!(lsn >= state.commit_lsn);
|
||||
state.commit_lsn = lsn;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
// update the node in the state struct
|
||||
*state.get(event.node) = node;
|
||||
}
|
||||
}
|
||||
265
libs/walproposer/src/simtest/wp_sk.rs
Normal file
265
libs/walproposer/src/simtest/wp_sk.rs
Normal file
@@ -0,0 +1,265 @@
|
||||
use std::{ffi::CString, path::Path, str::FromStr, sync::Arc};
|
||||
|
||||
use rand::Rng;
|
||||
use safekeeper::simlib::{
|
||||
network::{Delay, NetworkOptions},
|
||||
proto::AnyMessage,
|
||||
world::World,
|
||||
world::{Node, NodeEvent},
|
||||
};
|
||||
use tracing::{info, warn};
|
||||
use utils::{id::TenantTimelineId, lsn::Lsn};
|
||||
|
||||
use crate::{
|
||||
bindings::{
|
||||
neon_tenant_walproposer, neon_timeline_walproposer, sim_redo_start_lsn, syncSafekeepers,
|
||||
wal_acceptor_connection_timeout, wal_acceptor_reconnect_timeout, wal_acceptors_list,
|
||||
MyInsertRecord, WalProposerCleanup, WalProposerRust,
|
||||
},
|
||||
c_context,
|
||||
simtest::{
|
||||
log::{init_logger, SimClock},
|
||||
safekeeper::run_server,
|
||||
util::{generate_schedule, TestConfig, generate_network_opts, validate_events},
|
||||
}, enable_debug,
|
||||
};
|
||||
|
||||
use super::{
|
||||
disk::Disk,
|
||||
util::{Schedule, TestAction},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn sync_empty_safekeepers() {
|
||||
let clock = init_logger();
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
let test = config.start(1337);
|
||||
|
||||
let lsn = test.sync_safekeepers().unwrap();
|
||||
assert_eq!(lsn, Lsn(0));
|
||||
info!("Sucessfully synced empty safekeepers at 0/0");
|
||||
|
||||
let lsn = test.sync_safekeepers().unwrap();
|
||||
assert_eq!(lsn, Lsn(0));
|
||||
info!("Sucessfully synced (again) empty safekeepers at 0/0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn run_walproposer_generate_wal() {
|
||||
let clock = init_logger();
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
// config.network.timeout = Some(250);
|
||||
let test = config.start(1337);
|
||||
|
||||
let lsn = test.sync_safekeepers().unwrap();
|
||||
assert_eq!(lsn, Lsn(0));
|
||||
info!("Sucessfully synced empty safekeepers at 0/0");
|
||||
|
||||
let mut wp = test.launch_walproposer(lsn);
|
||||
|
||||
test.poll_for_duration(30);
|
||||
|
||||
for i in 0..100 {
|
||||
wp.write_tx(1);
|
||||
test.poll_for_duration(5);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn crash_safekeeper() {
|
||||
let clock = init_logger();
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
// config.network.timeout = Some(250);
|
||||
let test = config.start(1337);
|
||||
|
||||
let lsn = test.sync_safekeepers().unwrap();
|
||||
assert_eq!(lsn, Lsn(0));
|
||||
info!("Sucessfully synced empty safekeepers at 0/0");
|
||||
|
||||
let mut wp = test.launch_walproposer(lsn);
|
||||
|
||||
test.poll_for_duration(30);
|
||||
|
||||
wp.write_tx(3);
|
||||
|
||||
test.servers[0].restart();
|
||||
|
||||
test.poll_for_duration(100);
|
||||
test.poll_for_duration(1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_restart() {
|
||||
let clock = init_logger();
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
// config.network.timeout = Some(250);
|
||||
let test = config.start(1337);
|
||||
|
||||
let lsn = test.sync_safekeepers().unwrap();
|
||||
assert_eq!(lsn, Lsn(0));
|
||||
info!("Sucessfully synced empty safekeepers at 0/0");
|
||||
|
||||
let mut wp = test.launch_walproposer(lsn);
|
||||
|
||||
test.poll_for_duration(30);
|
||||
|
||||
wp.write_tx(3);
|
||||
test.poll_for_duration(100);
|
||||
|
||||
wp.stop();
|
||||
drop(wp);
|
||||
|
||||
let lsn = test.sync_safekeepers().unwrap();
|
||||
info!("Sucessfully synced safekeepers at {}", lsn);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_schedule() -> anyhow::Result<()> {
|
||||
let clock = init_logger();
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
config.network.keepalive_timeout = Some(100);
|
||||
let test = config.start(1337);
|
||||
|
||||
let schedule: Schedule = vec![
|
||||
(0, TestAction::RestartWalProposer),
|
||||
(50, TestAction::WriteTx(5)),
|
||||
(100, TestAction::RestartSafekeeper(0)),
|
||||
(100, TestAction::WriteTx(5)),
|
||||
(110, TestAction::RestartSafekeeper(1)),
|
||||
(110, TestAction::WriteTx(5)),
|
||||
(120, TestAction::RestartSafekeeper(2)),
|
||||
(120, TestAction::WriteTx(5)),
|
||||
(201, TestAction::RestartWalProposer),
|
||||
(251, TestAction::RestartSafekeeper(0)),
|
||||
(251, TestAction::RestartSafekeeper(1)),
|
||||
(251, TestAction::RestartSafekeeper(2)),
|
||||
(251, TestAction::WriteTx(5)),
|
||||
(255, TestAction::WriteTx(5)),
|
||||
(1000, TestAction::WriteTx(5)),
|
||||
];
|
||||
|
||||
test.run_schedule(&schedule)?;
|
||||
info!("Test finished, stopping all threads");
|
||||
test.world.deallocate();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_many_tx() -> anyhow::Result<()> {
|
||||
enable_debug();
|
||||
let clock = init_logger();
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
let test = config.start(1337);
|
||||
|
||||
let mut schedule: Schedule = vec![];
|
||||
for i in 0..100 {
|
||||
schedule.push((i * 10, TestAction::WriteTx(10)));
|
||||
}
|
||||
|
||||
test.run_schedule(&schedule)?;
|
||||
info!("Test finished, stopping all threads");
|
||||
test.world.stop_all();
|
||||
|
||||
let events = test.world.take_events();
|
||||
info!("Events: {:?}", events);
|
||||
let last_commit_lsn = events
|
||||
.iter()
|
||||
.filter_map(|event| {
|
||||
if event.data.starts_with("commit_lsn;") {
|
||||
let lsn: u64 = event.data.split(';').nth(1).unwrap().parse().unwrap();
|
||||
return Some(lsn);
|
||||
}
|
||||
None
|
||||
})
|
||||
.last()
|
||||
.unwrap();
|
||||
|
||||
let initdb_lsn = 21623024;
|
||||
let diff = last_commit_lsn - initdb_lsn;
|
||||
info!("Last commit lsn: {}, diff: {}", last_commit_lsn, diff);
|
||||
assert!(diff > 1000 * 8);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_random_schedules() -> anyhow::Result<()> {
|
||||
let clock = init_logger();
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
config.network.keepalive_timeout = Some(100);
|
||||
|
||||
for i in 0..30000 {
|
||||
let seed: u64 = rand::thread_rng().gen();
|
||||
config.network = generate_network_opts(seed);
|
||||
|
||||
let test = config.start(seed);
|
||||
warn!("Running test with seed {}", seed);
|
||||
|
||||
let schedule = generate_schedule(seed);
|
||||
test.run_schedule(&schedule).unwrap();
|
||||
validate_events(test.world.take_events());
|
||||
test.world.deallocate();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_one_schedule() -> anyhow::Result<()> {
|
||||
enable_debug();
|
||||
let clock = init_logger();
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
config.network.keepalive_timeout = Some(100);
|
||||
|
||||
// let seed = 6762900106769428342;
|
||||
// let test = config.start(seed);
|
||||
// warn!("Running test with seed {}", seed);
|
||||
|
||||
// let schedule = generate_schedule(seed);
|
||||
// info!("schedule: {:?}", schedule);
|
||||
// test.run_schedule(&schedule)?;
|
||||
// test.world.deallocate();
|
||||
|
||||
let seed = 3649773280641776194;
|
||||
config.network = generate_network_opts(seed);
|
||||
info!("network: {:?}", config.network);
|
||||
let test = config.start(seed);
|
||||
warn!("Running test with seed {}", seed);
|
||||
|
||||
let schedule = generate_schedule(seed);
|
||||
info!("schedule: {:?}", schedule);
|
||||
test.run_schedule(&schedule).unwrap();
|
||||
validate_events(test.world.take_events());
|
||||
test.world.deallocate();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_res_dealloc() -> anyhow::Result<()> {
|
||||
// enable_debug();
|
||||
let clock = init_logger();
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
|
||||
// print pid
|
||||
let pid = unsafe { libc::getpid() };
|
||||
info!("pid: {}", pid);
|
||||
|
||||
let seed = 123456;
|
||||
config.network = generate_network_opts(seed);
|
||||
let test = config.start(seed);
|
||||
warn!("Running test with seed {}", seed);
|
||||
|
||||
let schedule = generate_schedule(seed);
|
||||
info!("schedule: {:?}", schedule);
|
||||
test.run_schedule(&schedule).unwrap();
|
||||
test.world.stop_all();
|
||||
|
||||
let world = test.world.clone();
|
||||
drop(test);
|
||||
info!("world strong count: {}", Arc::strong_count(&world));
|
||||
world.deallocate();
|
||||
info!("world strong count: {}", Arc::strong_count(&world));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
31
libs/walproposer/src/test.rs
Normal file
31
libs/walproposer/src/test.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use tracing::info;
|
||||
|
||||
use crate::bindings::{TestFunc, MyContextInit};
|
||||
|
||||
#[test]
|
||||
fn test_rust_c_calls() {
|
||||
let res = std::thread::spawn(|| {
|
||||
let res = unsafe {
|
||||
MyContextInit();
|
||||
TestFunc(1, 2)
|
||||
};
|
||||
res
|
||||
}).join().unwrap();
|
||||
info!("res: {}", res);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sim_bindings() {
|
||||
std::thread::spawn(|| {
|
||||
unsafe {
|
||||
MyContextInit();
|
||||
TestFunc(1, 2)
|
||||
}
|
||||
}).join().unwrap();
|
||||
std::thread::spawn(|| {
|
||||
unsafe {
|
||||
MyContextInit();
|
||||
TestFunc(1, 2)
|
||||
}
|
||||
}).join().unwrap();
|
||||
}
|
||||
100
libs/walproposer/test.c
Normal file
100
libs/walproposer/test.c
Normal file
@@ -0,0 +1,100 @@
|
||||
#include "bindgen_deps.h"
|
||||
#include "rust_bindings.h"
|
||||
#include <stdio.h>
|
||||
#include <pthread.h>
|
||||
#include <stdlib.h>
|
||||
#include "postgres.h"
|
||||
#include "utils/memutils.h"
|
||||
#include "utils/guc.h"
|
||||
#include "miscadmin.h"
|
||||
#include "common/pg_prng.h"
|
||||
|
||||
// From src/backend/main/main.c
|
||||
const char *progname = "fakepostgres";
|
||||
|
||||
int TestFunc(int a, int b) {
|
||||
printf("TestFunc: %d + %d = %d\n", a, b, a + b);
|
||||
rust_function(0);
|
||||
elog(LOG, "postgres elog test");
|
||||
printf("After rust_function\n");
|
||||
return a + b;
|
||||
}
|
||||
|
||||
// This is a quick experiment with rewriting existing Rust code in C.
|
||||
void RunClientC(uint32_t serverId) {
|
||||
uint32_t clientId = sim_id();
|
||||
|
||||
elog(LOG, "started client");
|
||||
|
||||
int data_len = 5;
|
||||
|
||||
int delivered = 0;
|
||||
int tcp = sim_open_tcp(serverId);
|
||||
while (delivered < data_len) {
|
||||
sim_msg_set_repl_cell(delivered+1, clientId, delivered);
|
||||
sim_tcp_send(tcp);
|
||||
|
||||
Event event = sim_epoll_rcv(-1);
|
||||
switch (event.tag)
|
||||
{
|
||||
case Closed:
|
||||
elog(LOG, "connection closed");
|
||||
tcp = sim_open_tcp(serverId);
|
||||
break;
|
||||
|
||||
case Message:
|
||||
Assert(event.any_message == Just32);
|
||||
uint32_t val;
|
||||
sim_msg_get_just_u32(&val);
|
||||
if (val == delivered + 1) {
|
||||
delivered += 1;
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
Assert(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool debug_enabled = false;
|
||||
|
||||
bool initializedMemoryContext = false;
|
||||
// pthread_mutex_init(&lock, NULL)?
|
||||
pthread_mutex_t lock;
|
||||
|
||||
void MyContextInit() {
|
||||
// initializes global variables, TODO how to make them thread-local?
|
||||
pthread_mutex_lock(&lock);
|
||||
if (!initializedMemoryContext) {
|
||||
initializedMemoryContext = true;
|
||||
MemoryContextInit();
|
||||
pg_prng_seed(&pg_global_prng_state, 0);
|
||||
|
||||
setenv("PGDATA", "/home/admin/simulator/libs/walproposer/pgdata", 1);
|
||||
|
||||
/*
|
||||
* Set default values for command-line options.
|
||||
*/
|
||||
InitializeGUCOptions();
|
||||
|
||||
/* Acquire configuration parameters */
|
||||
if (!SelectConfigFiles(NULL, progname))
|
||||
exit(1);
|
||||
|
||||
if (debug_enabled) {
|
||||
log_min_messages = LOG;
|
||||
} else {
|
||||
log_min_messages = FATAL;
|
||||
}
|
||||
Log_line_prefix = "[%p] ";
|
||||
|
||||
InitializeMaxBackends();
|
||||
ChangeToDataDir();
|
||||
CreateSharedMemoryAndSemaphores();
|
||||
SetInstallXLogFileSegmentActive();
|
||||
// CreateAuxProcessResourceOwner();
|
||||
// StartupXLOG();
|
||||
}
|
||||
pthread_mutex_unlock(&lock);
|
||||
}
|
||||
@@ -37,7 +37,6 @@ num-traits.workspace = true
|
||||
once_cell.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
postgres.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
postgres-types.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
@@ -23,10 +23,11 @@ use pageserver::{
|
||||
tenant::mgr,
|
||||
virtual_file,
|
||||
};
|
||||
use postgres_backend::AuthType;
|
||||
use utils::{
|
||||
auth::JwtAuth,
|
||||
logging, project_git_version,
|
||||
logging,
|
||||
postgres_backend::AuthType,
|
||||
project_git_version,
|
||||
sentry_init::init_sentry,
|
||||
signals::{self, Signal},
|
||||
tcp_listener,
|
||||
@@ -270,43 +271,43 @@ fn start_pageserver(
|
||||
WALRECEIVER_RUNTIME.block_on(pageserver::broker_client::init_broker_client(conf))?;
|
||||
|
||||
// Initialize authentication for incoming connections
|
||||
let http_auth;
|
||||
let pg_auth;
|
||||
if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT {
|
||||
// unwrap is ok because check is performed when creating config, so path is set and file exists
|
||||
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
|
||||
info!(
|
||||
"Loading public key for verifying JWT tokens from {:#?}",
|
||||
key_path
|
||||
);
|
||||
let auth: Arc<JwtAuth> = Arc::new(JwtAuth::from_key_path(key_path)?);
|
||||
let auth = match &conf.auth_type {
|
||||
AuthType::Trust => None,
|
||||
AuthType::NeonJWT => {
|
||||
// unwrap is ok because check is performed when creating config, so path is set and file exists
|
||||
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
|
||||
Some(JwtAuth::from_key_path(key_path)?.into())
|
||||
}
|
||||
};
|
||||
info!("Using auth: {:#?}", conf.auth_type);
|
||||
|
||||
http_auth = match &conf.http_auth_type {
|
||||
AuthType::Trust => None,
|
||||
AuthType::NeonJWT => Some(auth.clone()),
|
||||
};
|
||||
pg_auth = match &conf.pg_auth_type {
|
||||
AuthType::Trust => None,
|
||||
AuthType::NeonJWT => Some(auth),
|
||||
};
|
||||
} else {
|
||||
http_auth = None;
|
||||
pg_auth = None;
|
||||
}
|
||||
info!("Using auth for http API: {:#?}", conf.http_auth_type);
|
||||
info!("Using auth for pg connections: {:#?}", conf.pg_auth_type);
|
||||
|
||||
match var("NEON_AUTH_TOKEN") {
|
||||
Ok(v) => {
|
||||
// TODO: remove ZENITH_AUTH_TOKEN once it's not used anywhere in development/staging/prod configuration.
|
||||
match (var("ZENITH_AUTH_TOKEN"), var("NEON_AUTH_TOKEN")) {
|
||||
(old, Ok(v)) => {
|
||||
info!("Loaded JWT token for authentication with Safekeeper");
|
||||
if let Ok(v_old) = old {
|
||||
warn!(
|
||||
"JWT token for Safekeeper is specified twice, ZENITH_AUTH_TOKEN is deprecated"
|
||||
);
|
||||
if v_old != v {
|
||||
warn!("JWT token for Safekeeper has two different values, choosing NEON_AUTH_TOKEN");
|
||||
}
|
||||
}
|
||||
pageserver::config::SAFEKEEPER_AUTH_TOKEN
|
||||
.set(Arc::new(v))
|
||||
.map_err(|_| anyhow!("Could not initialize SAFEKEEPER_AUTH_TOKEN"))?;
|
||||
}
|
||||
Err(VarError::NotPresent) => {
|
||||
(Ok(v), _) => {
|
||||
info!("Loaded JWT token for authentication with Safekeeper");
|
||||
warn!("Please update pageserver configuration: the JWT token should be NEON_AUTH_TOKEN, not ZENITH_AUTH_TOKEN");
|
||||
pageserver::config::SAFEKEEPER_AUTH_TOKEN
|
||||
.set(Arc::new(v))
|
||||
.map_err(|_| anyhow!("Could not initialize SAFEKEEPER_AUTH_TOKEN"))?;
|
||||
}
|
||||
(_, Err(VarError::NotPresent)) => {
|
||||
info!("No JWT token for authentication with Safekeeper detected");
|
||||
}
|
||||
Err(e) => {
|
||||
(_, Err(e)) => {
|
||||
return Err(e).with_context(|| {
|
||||
"Failed to either load to detect non-present NEON_AUTH_TOKEN environment variable"
|
||||
})
|
||||
@@ -324,7 +325,7 @@ fn start_pageserver(
|
||||
{
|
||||
let _rt_guard = MGMT_REQUEST_RUNTIME.enter();
|
||||
|
||||
let router = http::make_router(conf, launch_ts, http_auth, remote_storage)?
|
||||
let router = http::make_router(conf, launch_ts, auth.clone(), remote_storage)?
|
||||
.build()
|
||||
.map_err(|err| anyhow!(err))?;
|
||||
let service = utils::http::RouterService::new(router).unwrap();
|
||||
@@ -398,9 +399,9 @@ fn start_pageserver(
|
||||
async move {
|
||||
page_service::libpq_listener_main(
|
||||
conf,
|
||||
pg_auth,
|
||||
auth,
|
||||
pageserver_listener,
|
||||
conf.pg_auth_type,
|
||||
conf.auth_type,
|
||||
libpq_ctx,
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -21,10 +21,10 @@ use std::time::Duration;
|
||||
use toml_edit;
|
||||
use toml_edit::{Document, Item};
|
||||
|
||||
use postgres_backend::AuthType;
|
||||
use utils::{
|
||||
id::{NodeId, TenantId, TimelineId},
|
||||
logging::LogFormat,
|
||||
postgres_backend::AuthType,
|
||||
};
|
||||
|
||||
use crate::tenant::config::TenantConf;
|
||||
@@ -118,9 +118,6 @@ pub struct PageServerConf {
|
||||
/// Example (default): 127.0.0.1:9898
|
||||
pub listen_http_addr: String,
|
||||
|
||||
/// Current availability zone. Used for traffic metrics.
|
||||
pub availability_zone: Option<String>,
|
||||
|
||||
// Timeout when waiting for WAL receiver to catch up to an LSN given in a GetPage@LSN call.
|
||||
pub wait_lsn_timeout: Duration,
|
||||
// How long to wait for WAL redo to complete.
|
||||
@@ -141,15 +138,9 @@ pub struct PageServerConf {
|
||||
|
||||
pub pg_distrib_dir: PathBuf,
|
||||
|
||||
// Authentication
|
||||
/// authentication method for the HTTP mgmt API
|
||||
pub http_auth_type: AuthType,
|
||||
/// authentication method for libpq connections from compute
|
||||
pub pg_auth_type: AuthType,
|
||||
/// Path to a file containing public key for verifying JWT tokens.
|
||||
/// Used for both mgmt and compute auth, if enabled.
|
||||
pub auth_validation_public_key_path: Option<PathBuf>,
|
||||
pub auth_type: AuthType,
|
||||
|
||||
pub auth_validation_public_key_path: Option<PathBuf>,
|
||||
pub remote_storage_config: Option<RemoteStorageConfig>,
|
||||
|
||||
pub default_tenant_conf: TenantConf,
|
||||
@@ -205,8 +196,6 @@ struct PageServerConfigBuilder {
|
||||
|
||||
listen_http_addr: BuilderValue<String>,
|
||||
|
||||
availability_zone: BuilderValue<Option<String>>,
|
||||
|
||||
wait_lsn_timeout: BuilderValue<Duration>,
|
||||
wal_redo_timeout: BuilderValue<Duration>,
|
||||
|
||||
@@ -219,8 +208,7 @@ struct PageServerConfigBuilder {
|
||||
|
||||
pg_distrib_dir: BuilderValue<PathBuf>,
|
||||
|
||||
http_auth_type: BuilderValue<AuthType>,
|
||||
pg_auth_type: BuilderValue<AuthType>,
|
||||
auth_type: BuilderValue<AuthType>,
|
||||
|
||||
//
|
||||
auth_validation_public_key_path: BuilderValue<Option<PathBuf>>,
|
||||
@@ -252,7 +240,6 @@ impl Default for PageServerConfigBuilder {
|
||||
Self {
|
||||
listen_pg_addr: Set(DEFAULT_PG_LISTEN_ADDR.to_string()),
|
||||
listen_http_addr: Set(DEFAULT_HTTP_LISTEN_ADDR.to_string()),
|
||||
availability_zone: Set(None),
|
||||
wait_lsn_timeout: Set(humantime::parse_duration(DEFAULT_WAIT_LSN_TIMEOUT)
|
||||
.expect("cannot parse default wait lsn timeout")),
|
||||
wal_redo_timeout: Set(humantime::parse_duration(DEFAULT_WAL_REDO_TIMEOUT)
|
||||
@@ -264,8 +251,7 @@ impl Default for PageServerConfigBuilder {
|
||||
pg_distrib_dir: Set(env::current_dir()
|
||||
.expect("cannot access current directory")
|
||||
.join("pg_install")),
|
||||
http_auth_type: Set(AuthType::Trust),
|
||||
pg_auth_type: Set(AuthType::Trust),
|
||||
auth_type: Set(AuthType::Trust),
|
||||
auth_validation_public_key_path: Set(None),
|
||||
remote_storage_config: Set(None),
|
||||
id: NotSet,
|
||||
@@ -309,10 +295,6 @@ impl PageServerConfigBuilder {
|
||||
self.listen_http_addr = BuilderValue::Set(listen_http_addr)
|
||||
}
|
||||
|
||||
pub fn availability_zone(&mut self, availability_zone: Option<String>) {
|
||||
self.availability_zone = BuilderValue::Set(availability_zone)
|
||||
}
|
||||
|
||||
pub fn wait_lsn_timeout(&mut self, wait_lsn_timeout: Duration) {
|
||||
self.wait_lsn_timeout = BuilderValue::Set(wait_lsn_timeout)
|
||||
}
|
||||
@@ -341,12 +323,8 @@ impl PageServerConfigBuilder {
|
||||
self.pg_distrib_dir = BuilderValue::Set(pg_distrib_dir)
|
||||
}
|
||||
|
||||
pub fn http_auth_type(&mut self, auth_type: AuthType) {
|
||||
self.http_auth_type = BuilderValue::Set(auth_type)
|
||||
}
|
||||
|
||||
pub fn pg_auth_type(&mut self, auth_type: AuthType) {
|
||||
self.pg_auth_type = BuilderValue::Set(auth_type)
|
||||
pub fn auth_type(&mut self, auth_type: AuthType) {
|
||||
self.auth_type = BuilderValue::Set(auth_type)
|
||||
}
|
||||
|
||||
pub fn auth_validation_public_key_path(
|
||||
@@ -424,9 +402,6 @@ impl PageServerConfigBuilder {
|
||||
listen_http_addr: self
|
||||
.listen_http_addr
|
||||
.ok_or(anyhow!("missing listen_http_addr"))?,
|
||||
availability_zone: self
|
||||
.availability_zone
|
||||
.ok_or(anyhow!("missing availability_zone"))?,
|
||||
wait_lsn_timeout: self
|
||||
.wait_lsn_timeout
|
||||
.ok_or(anyhow!("missing wait_lsn_timeout"))?,
|
||||
@@ -444,10 +419,7 @@ impl PageServerConfigBuilder {
|
||||
pg_distrib_dir: self
|
||||
.pg_distrib_dir
|
||||
.ok_or(anyhow!("missing pg_distrib_dir"))?,
|
||||
http_auth_type: self
|
||||
.http_auth_type
|
||||
.ok_or(anyhow!("missing http_auth_type"))?,
|
||||
pg_auth_type: self.pg_auth_type.ok_or(anyhow!("missing pg_auth_type"))?,
|
||||
auth_type: self.auth_type.ok_or(anyhow!("missing auth_type"))?,
|
||||
auth_validation_public_key_path: self
|
||||
.auth_validation_public_key_path
|
||||
.ok_or(anyhow!("missing auth_validation_public_key_path"))?,
|
||||
@@ -627,7 +599,6 @@ impl PageServerConf {
|
||||
match key {
|
||||
"listen_pg_addr" => builder.listen_pg_addr(parse_toml_string(key, item)?),
|
||||
"listen_http_addr" => builder.listen_http_addr(parse_toml_string(key, item)?),
|
||||
"availability_zone" => builder.availability_zone(Some(parse_toml_string(key, item)?)),
|
||||
"wait_lsn_timeout" => builder.wait_lsn_timeout(parse_toml_duration(key, item)?),
|
||||
"wal_redo_timeout" => builder.wal_redo_timeout(parse_toml_duration(key, item)?),
|
||||
"initial_superuser_name" => builder.superuser(parse_toml_string(key, item)?),
|
||||
@@ -641,8 +612,7 @@ impl PageServerConf {
|
||||
"auth_validation_public_key_path" => builder.auth_validation_public_key_path(Some(
|
||||
PathBuf::from(parse_toml_string(key, item)?),
|
||||
)),
|
||||
"http_auth_type" => builder.http_auth_type(parse_toml_from_str(key, item)?),
|
||||
"pg_auth_type" => builder.pg_auth_type(parse_toml_from_str(key, item)?),
|
||||
"auth_type" => builder.auth_type(parse_toml_from_str(key, item)?),
|
||||
"remote_storage" => {
|
||||
builder.remote_storage_config(RemoteStorageConfig::from_toml(item)?)
|
||||
}
|
||||
@@ -677,7 +647,7 @@ impl PageServerConf {
|
||||
|
||||
let mut conf = builder.build().context("invalid config")?;
|
||||
|
||||
if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT {
|
||||
if conf.auth_type == AuthType::NeonJWT {
|
||||
let auth_validation_public_key_path = conf
|
||||
.auth_validation_public_key_path
|
||||
.get_or_insert_with(|| workdir.join("auth_public_key.pem"));
|
||||
@@ -728,12 +698,6 @@ impl PageServerConf {
|
||||
Some(parse_toml_u64("compaction_threshold", compaction_threshold)?.try_into()?);
|
||||
}
|
||||
|
||||
if let Some(image_creation_threshold) = item.get("image_creation_threshold") {
|
||||
t_conf.image_creation_threshold = Some(
|
||||
parse_toml_u64("image_creation_threshold", image_creation_threshold)?.try_into()?,
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(gc_horizon) = item.get("gc_horizon") {
|
||||
t_conf.gc_horizon = Some(parse_toml_u64("gc_horizon", gc_horizon)?);
|
||||
}
|
||||
@@ -793,12 +757,10 @@ impl PageServerConf {
|
||||
max_file_descriptors: defaults::DEFAULT_MAX_FILE_DESCRIPTORS,
|
||||
listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(),
|
||||
listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(),
|
||||
availability_zone: None,
|
||||
superuser: "cloud_admin".to_string(),
|
||||
workdir: repo_dir,
|
||||
pg_distrib_dir,
|
||||
http_auth_type: AuthType::Trust,
|
||||
pg_auth_type: AuthType::Trust,
|
||||
auth_type: AuthType::Trust,
|
||||
auth_validation_public_key_path: None,
|
||||
remote_storage_config: None,
|
||||
default_tenant_conf: TenantConf::default(),
|
||||
@@ -976,7 +938,6 @@ log_format = 'json'
|
||||
id: NodeId(10),
|
||||
listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(),
|
||||
listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(),
|
||||
availability_zone: None,
|
||||
wait_lsn_timeout: humantime::parse_duration(defaults::DEFAULT_WAIT_LSN_TIMEOUT)?,
|
||||
wal_redo_timeout: humantime::parse_duration(defaults::DEFAULT_WAL_REDO_TIMEOUT)?,
|
||||
superuser: defaults::DEFAULT_SUPERUSER.to_string(),
|
||||
@@ -984,8 +945,7 @@ log_format = 'json'
|
||||
max_file_descriptors: defaults::DEFAULT_MAX_FILE_DESCRIPTORS,
|
||||
workdir,
|
||||
pg_distrib_dir,
|
||||
http_auth_type: AuthType::Trust,
|
||||
pg_auth_type: AuthType::Trust,
|
||||
auth_type: AuthType::Trust,
|
||||
auth_validation_public_key_path: None,
|
||||
remote_storage_config: None,
|
||||
default_tenant_conf: TenantConf::default(),
|
||||
@@ -1035,7 +995,6 @@ log_format = 'json'
|
||||
id: NodeId(10),
|
||||
listen_pg_addr: "127.0.0.1:64000".to_string(),
|
||||
listen_http_addr: "127.0.0.1:9898".to_string(),
|
||||
availability_zone: None,
|
||||
wait_lsn_timeout: Duration::from_secs(111),
|
||||
wal_redo_timeout: Duration::from_secs(111),
|
||||
superuser: "zzzz".to_string(),
|
||||
@@ -1043,8 +1002,7 @@ log_format = 'json'
|
||||
max_file_descriptors: 333,
|
||||
workdir,
|
||||
pg_distrib_dir,
|
||||
http_auth_type: AuthType::Trust,
|
||||
pg_auth_type: AuthType::Trust,
|
||||
auth_type: AuthType::Trust,
|
||||
auth_validation_public_key_path: None,
|
||||
remote_storage_config: None,
|
||||
default_tenant_conf: TenantConf::default(),
|
||||
|
||||
@@ -245,53 +245,6 @@ paths:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
/v1/tenant/{tenant_id}/timeline/{timeline_id}/do_gc:
|
||||
parameters:
|
||||
- name: tenant_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
format: hex
|
||||
- name: timeline_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
format: hex
|
||||
put:
|
||||
description: Garbage collect given timeline
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: string
|
||||
"400":
|
||||
description: Error when no tenant id found in path, no timeline id or invalid timestamp
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"401":
|
||||
description: Unauthorized Error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/UnauthorizedError"
|
||||
"403":
|
||||
description: Forbidden Error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ForbiddenError"
|
||||
"500":
|
||||
description: Generic operation error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
/v1/tenant/{tenant_id}/attach:
|
||||
parameters:
|
||||
- name: tenant_id
|
||||
|
||||
@@ -10,7 +10,6 @@ use remote_storage::GenericRemoteStorage;
|
||||
use tenant_size_model::{SizeResult, StorageModel};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::*;
|
||||
use utils::http::endpoint::RequestSpan;
|
||||
use utils::http::request::{get_request_param, must_get_query_param, parse_query_param};
|
||||
|
||||
use super::models::{
|
||||
@@ -21,7 +20,7 @@ use crate::context::{DownloadBehavior, RequestContext};
|
||||
use crate::pgdatadir_mapping::LsnForTimestamp;
|
||||
use crate::task_mgr::TaskKind;
|
||||
use crate::tenant::config::TenantConfOpt;
|
||||
use crate::tenant::mgr::{TenantMapInsertError, TenantStateError};
|
||||
use crate::tenant::mgr::TenantMapInsertError;
|
||||
use crate::tenant::size::ModelInputs;
|
||||
use crate::tenant::storage_layer::LayerAccessStatsReset;
|
||||
use crate::tenant::{PageReconstructError, Timeline};
|
||||
@@ -82,52 +81,38 @@ fn get_config(request: &Request<Body>) -> &'static PageServerConf {
|
||||
get_state(request).conf
|
||||
}
|
||||
|
||||
/// Check that the requester is authorized to operate on given tenant
|
||||
fn check_permission(request: &Request<Body>, tenant_id: Option<TenantId>) -> Result<(), ApiError> {
|
||||
check_permission_with(request, |claims| {
|
||||
crate::auth::check_permission(claims, tenant_id)
|
||||
})
|
||||
}
|
||||
|
||||
impl From<PageReconstructError> for ApiError {
|
||||
fn from(pre: PageReconstructError) -> ApiError {
|
||||
match pre {
|
||||
PageReconstructError::Other(pre) => ApiError::InternalServerError(pre),
|
||||
PageReconstructError::NeedsDownload(_, _) => {
|
||||
// This shouldn't happen, because we use a RequestContext that requests to
|
||||
// download any missing layer files on-demand.
|
||||
ApiError::InternalServerError(anyhow::anyhow!("need to download remote layer file"))
|
||||
}
|
||||
PageReconstructError::Cancelled => {
|
||||
ApiError::InternalServerError(anyhow::anyhow!("request was cancelled"))
|
||||
}
|
||||
PageReconstructError::WalRedo(pre) => {
|
||||
ApiError::InternalServerError(anyhow::Error::new(pre))
|
||||
}
|
||||
fn apierror_from_prerror(err: PageReconstructError) -> ApiError {
|
||||
match err {
|
||||
PageReconstructError::Other(err) => ApiError::InternalServerError(err),
|
||||
PageReconstructError::NeedsDownload(_, _) => {
|
||||
// This shouldn't happen, because we use a RequestContext that requests to
|
||||
// download any missing layer files on-demand.
|
||||
ApiError::InternalServerError(anyhow::anyhow!("need to download remote layer file"))
|
||||
}
|
||||
PageReconstructError::Cancelled => {
|
||||
ApiError::InternalServerError(anyhow::anyhow!("request was cancelled"))
|
||||
}
|
||||
PageReconstructError::WalRedo(err) => {
|
||||
ApiError::InternalServerError(anyhow::Error::new(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TenantMapInsertError> for ApiError {
|
||||
fn from(tmie: TenantMapInsertError) -> ApiError {
|
||||
match tmie {
|
||||
TenantMapInsertError::StillInitializing | TenantMapInsertError::ShuttingDown => {
|
||||
ApiError::InternalServerError(anyhow::Error::new(tmie))
|
||||
}
|
||||
TenantMapInsertError::TenantAlreadyExists(id, state) => {
|
||||
ApiError::Conflict(format!("tenant {id} already exists, state: {state:?}"))
|
||||
}
|
||||
TenantMapInsertError::Closure(e) => ApiError::InternalServerError(e),
|
||||
fn apierror_from_tenant_map_insert_error(e: TenantMapInsertError) -> ApiError {
|
||||
match e {
|
||||
TenantMapInsertError::StillInitializing | TenantMapInsertError::ShuttingDown => {
|
||||
ApiError::InternalServerError(anyhow::Error::new(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TenantStateError> for ApiError {
|
||||
fn from(tse: TenantStateError) -> ApiError {
|
||||
match tse {
|
||||
TenantStateError::NotFound(tid) => ApiError::NotFound(anyhow!("tenant {}", tid)),
|
||||
_ => ApiError::InternalServerError(anyhow::Error::new(tse)),
|
||||
TenantMapInsertError::TenantAlreadyExists(id, state) => {
|
||||
ApiError::Conflict(format!("tenant {id} already exists, state: {state:?}"))
|
||||
}
|
||||
TenantMapInsertError::Closure(e) => ApiError::InternalServerError(e),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,7 +216,9 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
|
||||
|
||||
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Error);
|
||||
|
||||
let tenant = mgr::get_tenant(tenant_id, true).await?;
|
||||
let tenant = mgr::get_tenant(tenant_id, true)
|
||||
.await
|
||||
.map_err(ApiError::NotFound)?;
|
||||
match tenant.create_timeline(
|
||||
new_timeline_id,
|
||||
request_data.ancestor_timeline_id.map(TimelineId::from),
|
||||
@@ -261,7 +248,9 @@ async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>,
|
||||
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
|
||||
|
||||
let response_data = async {
|
||||
let tenant = mgr::get_tenant(tenant_id, true).await?;
|
||||
let tenant = mgr::get_tenant(tenant_id, true)
|
||||
.await
|
||||
.map_err(ApiError::NotFound)?;
|
||||
let timelines = tenant.list_timelines();
|
||||
|
||||
let mut response_data = Vec::with_capacity(timelines.len());
|
||||
@@ -277,7 +266,7 @@ async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>,
|
||||
|
||||
response_data.push(timeline_info);
|
||||
}
|
||||
Ok::<Vec<TimelineInfo>, ApiError>(response_data)
|
||||
Ok(response_data)
|
||||
}
|
||||
.instrument(info_span!("timeline_list", tenant = %tenant_id))
|
||||
.await?;
|
||||
@@ -296,7 +285,9 @@ async fn timeline_detail_handler(request: Request<Body>) -> Result<Response<Body
|
||||
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
|
||||
|
||||
let timeline_info = async {
|
||||
let tenant = mgr::get_tenant(tenant_id, true).await?;
|
||||
let tenant = mgr::get_tenant(tenant_id, true)
|
||||
.await
|
||||
.map_err(ApiError::NotFound)?;
|
||||
|
||||
let timeline = tenant
|
||||
.get_timeline(timeline_id, false)
|
||||
@@ -332,7 +323,10 @@ async fn get_lsn_by_timestamp_handler(request: Request<Body>) -> Result<Response
|
||||
|
||||
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
|
||||
let timeline = active_timeline_of_active_tenant(tenant_id, timeline_id).await?;
|
||||
let result = timeline.find_lsn_for_timestamp(timestamp_pg, &ctx).await?;
|
||||
let result = timeline
|
||||
.find_lsn_for_timestamp(timestamp_pg, &ctx)
|
||||
.await
|
||||
.map_err(apierror_from_prerror)?;
|
||||
|
||||
let result = match result {
|
||||
LsnForTimestamp::Present(lsn) => format!("{lsn}"),
|
||||
@@ -357,7 +351,8 @@ async fn tenant_attach_handler(request: Request<Body>) -> Result<Response<Body>,
|
||||
if let Some(remote_storage) = &state.remote_storage {
|
||||
mgr::attach_tenant(state.conf, tenant_id, remote_storage.clone(), &ctx)
|
||||
.instrument(info_span!("tenant_attach", tenant = %tenant_id))
|
||||
.await?;
|
||||
.await
|
||||
.map_err(apierror_from_tenant_map_insert_error)?;
|
||||
} else {
|
||||
return Err(ApiError::BadRequest(anyhow!(
|
||||
"attach_tenant is not possible because pageserver was configured without remote storage"
|
||||
@@ -376,7 +371,11 @@ async fn timeline_delete_handler(request: Request<Body>) -> Result<Response<Body
|
||||
|
||||
mgr::delete_timeline(tenant_id, timeline_id, &ctx)
|
||||
.instrument(info_span!("timeline_delete", tenant = %tenant_id, timeline = %timeline_id))
|
||||
.await?;
|
||||
.await
|
||||
// FIXME: Errors from `delete_timeline` can occur for a number of reasons, incuding both
|
||||
// user and internal errors. Replace this with better handling once the error type permits
|
||||
// it.
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
@@ -389,7 +388,10 @@ async fn tenant_detach_handler(request: Request<Body>) -> Result<Response<Body>,
|
||||
let conf = state.conf;
|
||||
mgr::detach_tenant(conf, tenant_id)
|
||||
.instrument(info_span!("tenant_detach", tenant = %tenant_id))
|
||||
.await?;
|
||||
.await
|
||||
// FIXME: Errors from `detach_tenant` can be caused by both both user and internal errors.
|
||||
// Replace this with better handling once the error type permits it.
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
@@ -403,7 +405,8 @@ async fn tenant_load_handler(request: Request<Body>) -> Result<Response<Body>, A
|
||||
let state = get_state(&request);
|
||||
mgr::load_tenant(state.conf, tenant_id, state.remote_storage.clone(), &ctx)
|
||||
.instrument(info_span!("load", tenant = %tenant_id))
|
||||
.await?;
|
||||
.await
|
||||
.map_err(apierror_from_tenant_map_insert_error)?;
|
||||
|
||||
json_response(StatusCode::ACCEPTED, ())
|
||||
}
|
||||
@@ -416,7 +419,10 @@ async fn tenant_ignore_handler(request: Request<Body>) -> Result<Response<Body>,
|
||||
let conf = state.conf;
|
||||
mgr::ignore_tenant(conf, tenant_id)
|
||||
.instrument(info_span!("ignore_tenant", tenant = %tenant_id))
|
||||
.await?;
|
||||
.await
|
||||
// FIXME: Errors from `ignore_tenant` can be caused by both both user and internal errors.
|
||||
// Replace this with better handling once the error type permits it.
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
@@ -490,7 +496,9 @@ async fn tenant_size_handler(request: Request<Body>) -> Result<Response<Body>, A
|
||||
let headers = request.headers();
|
||||
|
||||
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
|
||||
let tenant = mgr::get_tenant(tenant_id, true).await?;
|
||||
let tenant = mgr::get_tenant(tenant_id, true)
|
||||
.await
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
|
||||
// this can be long operation
|
||||
let inputs = tenant
|
||||
@@ -753,7 +761,8 @@ async fn tenant_create_handler(mut request: Request<Body>) -> Result<Response<Bo
|
||||
&ctx,
|
||||
)
|
||||
.instrument(info_span!("tenant_create", tenant = ?target_tenant_id))
|
||||
.await?;
|
||||
.await
|
||||
.map_err(apierror_from_tenant_map_insert_error)?;
|
||||
|
||||
// We created the tenant. Existing API semantics are that the tenant
|
||||
// is Active when this function returns.
|
||||
@@ -777,7 +786,9 @@ async fn get_tenant_config_handler(request: Request<Body>) -> Result<Response<Bo
|
||||
let tenant_id: TenantId = parse_request_param(&request, "tenant_id")?;
|
||||
check_permission(&request, Some(tenant_id))?;
|
||||
|
||||
let tenant = mgr::get_tenant(tenant_id, false).await?;
|
||||
let tenant = mgr::get_tenant(tenant_id, false)
|
||||
.await
|
||||
.map_err(ApiError::NotFound)?;
|
||||
|
||||
let response = HashMap::from([
|
||||
(
|
||||
@@ -872,7 +883,10 @@ async fn update_tenant_config_handler(
|
||||
let state = get_state(&request);
|
||||
mgr::set_new_tenant_config(state.conf, tenant_conf, tenant_id)
|
||||
.instrument(info_span!("tenant_config", tenant = ?tenant_id))
|
||||
.await?;
|
||||
.await
|
||||
// FIXME: `update_tenant_config` can fail because of both user and internal errors.
|
||||
// Replace this `map_err` with better error handling once the type permits it
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
@@ -1009,7 +1023,9 @@ async fn active_timeline_of_active_tenant(
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
) -> Result<Arc<Timeline>, ApiError> {
|
||||
let tenant = mgr::get_tenant(tenant_id, true).await?;
|
||||
let tenant = mgr::get_tenant(tenant_id, true)
|
||||
.await
|
||||
.map_err(ApiError::NotFound)?;
|
||||
tenant
|
||||
.get_timeline(timeline_id, true)
|
||||
.map_err(ApiError::NotFound)
|
||||
@@ -1075,8 +1091,7 @@ pub fn make_router(
|
||||
let handler = $handler;
|
||||
#[cfg(not(feature = "testing"))]
|
||||
let handler = cfg_disabled;
|
||||
|
||||
move |r| RequestSpan(handler).handle(r)
|
||||
handler
|
||||
}};
|
||||
}
|
||||
|
||||
@@ -1084,55 +1099,35 @@ pub fn make_router(
|
||||
.data(Arc::new(
|
||||
State::new(conf, auth, remote_storage).context("Failed to initialize router state")?,
|
||||
))
|
||||
.get("/v1/status", |r| RequestSpan(status_handler).handle(r))
|
||||
.get("/v1/status", status_handler)
|
||||
.put(
|
||||
"/v1/failpoints",
|
||||
testing_api!("manage failpoints", failpoints_handler),
|
||||
)
|
||||
.get("/v1/tenant", |r| RequestSpan(tenant_list_handler).handle(r))
|
||||
.post("/v1/tenant", |r| {
|
||||
RequestSpan(tenant_create_handler).handle(r)
|
||||
})
|
||||
.get("/v1/tenant/:tenant_id", |r| {
|
||||
RequestSpan(tenant_status).handle(r)
|
||||
})
|
||||
.get("/v1/tenant/:tenant_id/synthetic_size", |r| {
|
||||
RequestSpan(tenant_size_handler).handle(r)
|
||||
})
|
||||
.put("/v1/tenant/config", |r| {
|
||||
RequestSpan(update_tenant_config_handler).handle(r)
|
||||
})
|
||||
.get("/v1/tenant/:tenant_id/config", |r| {
|
||||
RequestSpan(get_tenant_config_handler).handle(r)
|
||||
})
|
||||
.get("/v1/tenant/:tenant_id/timeline", |r| {
|
||||
RequestSpan(timeline_list_handler).handle(r)
|
||||
})
|
||||
.post("/v1/tenant/:tenant_id/timeline", |r| {
|
||||
RequestSpan(timeline_create_handler).handle(r)
|
||||
})
|
||||
.post("/v1/tenant/:tenant_id/attach", |r| {
|
||||
RequestSpan(tenant_attach_handler).handle(r)
|
||||
})
|
||||
.post("/v1/tenant/:tenant_id/detach", |r| {
|
||||
RequestSpan(tenant_detach_handler).handle(r)
|
||||
})
|
||||
.post("/v1/tenant/:tenant_id/load", |r| {
|
||||
RequestSpan(tenant_load_handler).handle(r)
|
||||
})
|
||||
.post("/v1/tenant/:tenant_id/ignore", |r| {
|
||||
RequestSpan(tenant_ignore_handler).handle(r)
|
||||
})
|
||||
.get("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
|
||||
RequestSpan(timeline_detail_handler).handle(r)
|
||||
})
|
||||
.get("/v1/tenant", tenant_list_handler)
|
||||
.post("/v1/tenant", tenant_create_handler)
|
||||
.get("/v1/tenant/:tenant_id", tenant_status)
|
||||
.get("/v1/tenant/:tenant_id/synthetic_size", tenant_size_handler)
|
||||
.put("/v1/tenant/config", update_tenant_config_handler)
|
||||
.get("/v1/tenant/:tenant_id/config", get_tenant_config_handler)
|
||||
.get("/v1/tenant/:tenant_id/timeline", timeline_list_handler)
|
||||
.post("/v1/tenant/:tenant_id/timeline", timeline_create_handler)
|
||||
.post("/v1/tenant/:tenant_id/attach", tenant_attach_handler)
|
||||
.post("/v1/tenant/:tenant_id/detach", tenant_detach_handler)
|
||||
.post("/v1/tenant/:tenant_id/load", tenant_load_handler)
|
||||
.post("/v1/tenant/:tenant_id/ignore", tenant_ignore_handler)
|
||||
.get(
|
||||
"/v1/tenant/:tenant_id/timeline/:timeline_id",
|
||||
timeline_detail_handler,
|
||||
)
|
||||
.get(
|
||||
"/v1/tenant/:tenant_id/timeline/:timeline_id/get_lsn_by_timestamp",
|
||||
|r| RequestSpan(get_lsn_by_timestamp_handler).handle(r),
|
||||
get_lsn_by_timestamp_handler,
|
||||
)
|
||||
.put(
|
||||
"/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc",
|
||||
timeline_gc_handler,
|
||||
)
|
||||
.put("/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc", |r| {
|
||||
RequestSpan(timeline_gc_handler).handle(r)
|
||||
})
|
||||
.put(
|
||||
"/v1/tenant/:tenant_id/timeline/:timeline_id/compact",
|
||||
testing_api!("run timeline compaction", timeline_compact_handler),
|
||||
@@ -1143,26 +1138,28 @@ pub fn make_router(
|
||||
)
|
||||
.post(
|
||||
"/v1/tenant/:tenant_id/timeline/:timeline_id/download_remote_layers",
|
||||
|r| RequestSpan(timeline_download_remote_layers_handler_post).handle(r),
|
||||
timeline_download_remote_layers_handler_post,
|
||||
)
|
||||
.get(
|
||||
"/v1/tenant/:tenant_id/timeline/:timeline_id/download_remote_layers",
|
||||
|r| RequestSpan(timeline_download_remote_layers_handler_get).handle(r),
|
||||
timeline_download_remote_layers_handler_get,
|
||||
)
|
||||
.delete(
|
||||
"/v1/tenant/:tenant_id/timeline/:timeline_id",
|
||||
timeline_delete_handler,
|
||||
)
|
||||
.get(
|
||||
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer",
|
||||
layer_map_info_handler,
|
||||
)
|
||||
.delete("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
|
||||
RequestSpan(timeline_delete_handler).handle(r)
|
||||
})
|
||||
.get("/v1/tenant/:tenant_id/timeline/:timeline_id/layer", |r| {
|
||||
RequestSpan(layer_map_info_handler).handle(r)
|
||||
})
|
||||
.get(
|
||||
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer/:layer_file_name",
|
||||
|r| RequestSpan(layer_download_handler).handle(r),
|
||||
layer_download_handler,
|
||||
)
|
||||
.delete(
|
||||
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer/:layer_file_name",
|
||||
|r| RequestSpan(evict_timeline_layer_handler).handle(r),
|
||||
evict_timeline_layer_handler,
|
||||
)
|
||||
.get("/v1/panic", |r| RequestSpan(always_panic_handler).handle(r))
|
||||
.get("/v1/panic", always_panic_handler)
|
||||
.any(handler_404))
|
||||
}
|
||||
|
||||
@@ -123,22 +123,6 @@ static REMOTE_PHYSICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
|
||||
.expect("failed to define a metric")
|
||||
});
|
||||
|
||||
pub static REMOTE_ONDEMAND_DOWNLOADED_LAYERS: Lazy<IntCounter> = Lazy::new(|| {
|
||||
register_int_counter!(
|
||||
"pageserver_remote_ondemand_downloaded_layers_total",
|
||||
"Total on-demand downloaded layers"
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
pub static REMOTE_ONDEMAND_DOWNLOADED_BYTES: Lazy<IntCounter> = Lazy::new(|| {
|
||||
register_int_counter!(
|
||||
"pageserver_remote_ondemand_downloaded_bytes_total",
|
||||
"Total bytes of layers on-demand downloaded",
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
static CURRENT_LOGICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
|
||||
register_uint_gauge_vec!(
|
||||
"pageserver_current_logical_size",
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
use anyhow::Context;
|
||||
use bytes::Buf;
|
||||
use bytes::Bytes;
|
||||
use futures::Stream;
|
||||
use futures::{Stream, StreamExt};
|
||||
use pageserver_api::models::TenantState;
|
||||
use pageserver_api::models::{
|
||||
PagestreamBeMessage, PagestreamDbSizeRequest, PagestreamDbSizeResponse,
|
||||
@@ -20,9 +20,7 @@ use pageserver_api::models::{
|
||||
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
|
||||
PagestreamNblocksRequest, PagestreamNblocksResponse,
|
||||
};
|
||||
use postgres_backend::PostgresBackendTCP;
|
||||
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
|
||||
use pq_proto::framed::ConnectionError;
|
||||
use pq_proto::ConnectionError;
|
||||
use pq_proto::FeStartupPacket;
|
||||
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
|
||||
use std::io;
|
||||
@@ -31,13 +29,14 @@ use std::str;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio_util::io::StreamReader;
|
||||
use tracing::*;
|
||||
use utils::id::ConnectionId;
|
||||
use utils::{
|
||||
auth::{Claims, JwtAuth, Scope},
|
||||
id::{TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::AuthType,
|
||||
postgres_backend_async::{self, is_expected_io_error, PostgresBackend, QueryError},
|
||||
simple_rcu::RcuReadGuard,
|
||||
};
|
||||
|
||||
@@ -56,7 +55,7 @@ use crate::trace::Tracer;
|
||||
use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
|
||||
use postgres_ffi::BLCKSZ;
|
||||
|
||||
fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<Bytes>> + '_ {
|
||||
fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Bytes>> + '_ {
|
||||
async_stream::try_stream! {
|
||||
loop {
|
||||
let msg = tokio::select! {
|
||||
@@ -65,11 +64,11 @@ fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<
|
||||
_ = task_mgr::shutdown_watcher() => {
|
||||
// We were requested to shut down.
|
||||
let msg = format!("pageserver is shutting down");
|
||||
let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None));
|
||||
let _ = pgb.write_message(&BeMessage::ErrorResponse(&msg, None));
|
||||
Err(QueryError::Other(anyhow::anyhow!(msg)))
|
||||
}
|
||||
|
||||
msg = pgb.read_message() => { msg.map_err(QueryError::from)}
|
||||
msg = pgb.read_message() => { msg }
|
||||
};
|
||||
|
||||
match msg {
|
||||
@@ -80,16 +79,14 @@ fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<
|
||||
FeMessage::Sync => continue,
|
||||
FeMessage::Terminate => {
|
||||
let msg = "client terminated connection with Terminate message during COPY";
|
||||
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
// error can't happen here, ErrorResponse serialization should be always ok
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
|
||||
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
|
||||
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
||||
break;
|
||||
}
|
||||
m => {
|
||||
let msg = format!("unexpected message {m:?}");
|
||||
// error can't happen here, ErrorResponse serialization should be always ok
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?;
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))?;
|
||||
Err(io::Error::new(io::ErrorKind::Other, msg))?;
|
||||
break;
|
||||
}
|
||||
@@ -99,66 +96,22 @@ fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<
|
||||
}
|
||||
Ok(None) => {
|
||||
let msg = "client closed connection during COPY";
|
||||
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
// error can't happen here, ErrorResponse serialization should be always ok
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
|
||||
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
|
||||
pgb.flush().await?;
|
||||
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
||||
}
|
||||
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
|
||||
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
|
||||
Err(io_error)?;
|
||||
}
|
||||
Err(other) => {
|
||||
Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
|
||||
Err(io::Error::new(io::ErrorKind::Other, other))?;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read the end of a tar archive.
|
||||
///
|
||||
/// A tar archive normally ends with two consecutive blocks of zeros, 512 bytes each.
|
||||
/// `tokio_tar` already read the first such block. Read the second all-zeros block,
|
||||
/// and check that there is no more data after the EOF marker.
|
||||
///
|
||||
/// XXX: Currently, any trailing data after the EOF marker prints a warning.
|
||||
/// Perhaps it should be a hard error?
|
||||
async fn read_tar_eof(mut reader: (impl tokio::io::AsyncRead + Unpin)) -> anyhow::Result<()> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
let mut buf = [0u8; 512];
|
||||
|
||||
// Read the all-zeros block, and verify it
|
||||
let mut total_bytes = 0;
|
||||
while total_bytes < 512 {
|
||||
let nbytes = reader.read(&mut buf[total_bytes..]).await?;
|
||||
total_bytes += nbytes;
|
||||
if nbytes == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if total_bytes < 512 {
|
||||
anyhow::bail!("incomplete or invalid tar EOF marker");
|
||||
}
|
||||
if !buf.iter().all(|&x| x == 0) {
|
||||
anyhow::bail!("invalid tar EOF marker");
|
||||
}
|
||||
|
||||
// Drain any data after the EOF marker
|
||||
let mut trailing_bytes = 0;
|
||||
loop {
|
||||
let nbytes = reader.read(&mut buf).await?;
|
||||
trailing_bytes += nbytes;
|
||||
if nbytes == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if trailing_bytes > 0 {
|
||||
warn!("ignored {trailing_bytes} unexpected bytes after the tar archive");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
///
|
||||
@@ -259,7 +212,7 @@ async fn page_service_conn_main(
|
||||
// we've been requested to shut down
|
||||
Ok(())
|
||||
}
|
||||
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
|
||||
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
|
||||
if is_expected_io_error(&io_error) {
|
||||
info!("Postgres client disconnected ({io_error})");
|
||||
Ok(())
|
||||
@@ -333,7 +286,7 @@ impl PageServerHandler {
|
||||
#[instrument(skip(self, pgb, ctx))]
|
||||
async fn handle_pagerequests(
|
||||
&self,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
pgb: &mut PostgresBackend,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
ctx: RequestContext,
|
||||
@@ -358,7 +311,7 @@ impl PageServerHandler {
|
||||
let timeline = tenant.get_timeline(timeline_id, true)?;
|
||||
|
||||
// switch client to COPYBOTH
|
||||
pgb.write_message_noflush(&BeMessage::CopyBothResponse)?;
|
||||
pgb.write_message(&BeMessage::CopyBothResponse)?;
|
||||
pgb.flush().await?;
|
||||
|
||||
let metrics = PageRequestMetrics::new(&tenant_id, &timeline_id);
|
||||
@@ -427,7 +380,7 @@ impl PageServerHandler {
|
||||
})
|
||||
});
|
||||
|
||||
pgb.write_message_noflush(&BeMessage::CopyData(&response.serialize()))?;
|
||||
pgb.write_message(&BeMessage::CopyData(&response.serialize()))?;
|
||||
pgb.flush().await?;
|
||||
}
|
||||
Ok(())
|
||||
@@ -437,7 +390,7 @@ impl PageServerHandler {
|
||||
#[instrument(skip(self, pgb, ctx))]
|
||||
async fn handle_import_basebackup(
|
||||
&self,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
pgb: &mut PostgresBackend,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
base_lsn: Lsn,
|
||||
@@ -463,17 +416,22 @@ impl PageServerHandler {
|
||||
|
||||
// Import basebackup provided via CopyData
|
||||
info!("importing basebackup");
|
||||
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
|
||||
pgb.write_message(&BeMessage::CopyInResponse)?;
|
||||
pgb.flush().await?;
|
||||
|
||||
let copyin_reader = StreamReader::new(copyin_stream(pgb));
|
||||
tokio::pin!(copyin_reader);
|
||||
let mut copyin_stream = Box::pin(copyin_stream(pgb));
|
||||
timeline
|
||||
.import_basebackup_from_tar(&mut copyin_reader, base_lsn, &ctx)
|
||||
.import_basebackup_from_tar(&mut copyin_stream, base_lsn, &ctx)
|
||||
.await?;
|
||||
|
||||
// Read the end of the tar archive.
|
||||
read_tar_eof(copyin_reader).await?;
|
||||
// Drain the rest of the Copy data
|
||||
let mut bytes_after_tar = 0;
|
||||
while let Some(bytes) = copyin_stream.next().await {
|
||||
bytes_after_tar += bytes?.len();
|
||||
}
|
||||
if bytes_after_tar > 0 {
|
||||
warn!("ignored {bytes_after_tar} unexpected bytes after the tar archive");
|
||||
}
|
||||
|
||||
// TODO check checksum
|
||||
// Meanwhile you can verify client-side by taking fullbackup
|
||||
@@ -488,7 +446,7 @@ impl PageServerHandler {
|
||||
#[instrument(skip(self, pgb, ctx))]
|
||||
async fn handle_import_wal(
|
||||
&self,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
pgb: &mut PostgresBackend,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
start_lsn: Lsn,
|
||||
@@ -510,15 +468,21 @@ impl PageServerHandler {
|
||||
|
||||
// Import wal provided via CopyData
|
||||
info!("importing wal");
|
||||
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
|
||||
pgb.write_message(&BeMessage::CopyInResponse)?;
|
||||
pgb.flush().await?;
|
||||
let copyin_reader = StreamReader::new(copyin_stream(pgb));
|
||||
tokio::pin!(copyin_reader);
|
||||
import_wal_from_tar(&timeline, &mut copyin_reader, start_lsn, end_lsn, &ctx).await?;
|
||||
let mut copyin_stream = Box::pin(copyin_stream(pgb));
|
||||
let mut reader = tokio_util::io::StreamReader::new(&mut copyin_stream);
|
||||
import_wal_from_tar(&timeline, &mut reader, start_lsn, end_lsn, &ctx).await?;
|
||||
info!("wal import complete");
|
||||
|
||||
// Read the end of the tar archive.
|
||||
read_tar_eof(copyin_reader).await?;
|
||||
// Drain the rest of the Copy data
|
||||
let mut bytes_after_tar = 0;
|
||||
while let Some(bytes) = copyin_stream.next().await {
|
||||
bytes_after_tar += bytes?.len();
|
||||
}
|
||||
if bytes_after_tar > 0 {
|
||||
warn!("ignored {bytes_after_tar} unexpected bytes after the tar archive");
|
||||
}
|
||||
|
||||
// TODO Does it make sense to overshoot?
|
||||
if timeline.get_last_record_lsn() < end_lsn {
|
||||
@@ -693,7 +657,7 @@ impl PageServerHandler {
|
||||
#[instrument(skip(self, pgb, ctx))]
|
||||
async fn handle_basebackup_request(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
pgb: &mut PostgresBackend,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
lsn: Option<Lsn>,
|
||||
@@ -714,7 +678,7 @@ impl PageServerHandler {
|
||||
}
|
||||
|
||||
// switch client to COPYOUT
|
||||
pgb.write_message_noflush(&BeMessage::CopyOutResponse)?;
|
||||
pgb.write_message(&BeMessage::CopyOutResponse)?;
|
||||
pgb.flush().await?;
|
||||
|
||||
// Send a tarball of the latest layer on the timeline
|
||||
@@ -731,7 +695,7 @@ impl PageServerHandler {
|
||||
.await?;
|
||||
}
|
||||
|
||||
pgb.write_message_noflush(&BeMessage::CopyDone)?;
|
||||
pgb.write_message(&BeMessage::CopyDone)?;
|
||||
pgb.flush().await?;
|
||||
info!("basebackup complete");
|
||||
|
||||
@@ -757,10 +721,10 @@ impl PageServerHandler {
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
impl postgres_backend_async::Handler for PageServerHandler {
|
||||
fn check_auth_jwt(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackendTCP,
|
||||
_pgb: &mut PostgresBackend,
|
||||
jwt_response: &[u8],
|
||||
) -> Result<(), QueryError> {
|
||||
// this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
|
||||
@@ -788,7 +752,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackendTCP,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_sm: &FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
Ok(())
|
||||
@@ -796,7 +760,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
let ctx = self.connection_ctx.attached_child();
|
||||
@@ -848,7 +812,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
// Check that the timeline exists
|
||||
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, None, false, ctx)
|
||||
.await?;
|
||||
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
}
|
||||
// return pair of prev_lsn and last_lsn
|
||||
else if query_string.starts_with("get_last_record_rlsn ") {
|
||||
@@ -871,15 +835,15 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
|
||||
let end_of_timeline = timeline.get_last_record_rlsn();
|
||||
|
||||
pgb.write_message_noflush(&BeMessage::RowDescription(&[
|
||||
pgb.write_message(&BeMessage::RowDescription(&[
|
||||
RowDescriptor::text_col(b"prev_lsn"),
|
||||
RowDescriptor::text_col(b"last_lsn"),
|
||||
]))?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[
|
||||
.write_message(&BeMessage::DataRow(&[
|
||||
Some(end_of_timeline.prev.to_string().as_bytes()),
|
||||
Some(end_of_timeline.last.to_string().as_bytes()),
|
||||
]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
}
|
||||
// same as basebackup, but result includes relational data as well
|
||||
else if query_string.starts_with("fullbackup ") {
|
||||
@@ -920,7 +884,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
// Check that the timeline exists
|
||||
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, prev_lsn, true, ctx)
|
||||
.await?;
|
||||
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else if query_string.starts_with("import basebackup ") {
|
||||
// Import the `base` section (everything but the wal) of a basebackup.
|
||||
// Assumes the tenant already exists on this pageserver.
|
||||
@@ -965,10 +929,10 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
|
||||
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
|
||||
Err(e) => {
|
||||
error!("error importing base backup between {base_lsn} and {end_lsn}: {e:?}");
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
pgb.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?
|
||||
@@ -1001,10 +965,10 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
.handle_import_wal(pgb, tenant_id, timeline_id, start_lsn, end_lsn, ctx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
|
||||
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
|
||||
Err(e) => {
|
||||
error!("error importing WAL between {start_lsn} and {end_lsn}: {e:?}");
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
pgb.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?
|
||||
@@ -1013,7 +977,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
} else if query_string.to_ascii_lowercase().starts_with("set ") {
|
||||
// important because psycopg2 executes "SET datestyle TO 'ISO'"
|
||||
// on connect
|
||||
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else if query_string.starts_with("show ") {
|
||||
// show <tenant_id>
|
||||
let (_, params_raw) = query_string.split_at("show ".len());
|
||||
@@ -1029,7 +993,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
self.check_permission(Some(tenant_id))?;
|
||||
|
||||
let tenant = get_active_tenant_with_timeout(tenant_id, &ctx).await?;
|
||||
pgb.write_message_noflush(&BeMessage::RowDescription(&[
|
||||
pgb.write_message(&BeMessage::RowDescription(&[
|
||||
RowDescriptor::int8_col(b"checkpoint_distance"),
|
||||
RowDescriptor::int8_col(b"checkpoint_timeout"),
|
||||
RowDescriptor::int8_col(b"compaction_target_size"),
|
||||
@@ -1040,7 +1004,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
RowDescriptor::int8_col(b"image_creation_threshold"),
|
||||
RowDescriptor::int8_col(b"pitr_interval"),
|
||||
]))?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[
|
||||
.write_message(&BeMessage::DataRow(&[
|
||||
Some(tenant.get_checkpoint_distance().to_string().as_bytes()),
|
||||
Some(
|
||||
tenant
|
||||
@@ -1063,7 +1027,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
Some(tenant.get_image_creation_threshold().to_string().as_bytes()),
|
||||
Some(tenant.get_pitr_interval().as_secs().to_string().as_bytes()),
|
||||
]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unknown command {query_string}"
|
||||
@@ -1091,7 +1055,7 @@ impl From<GetActiveTenantError> for QueryError {
|
||||
fn from(e: GetActiveTenantError) -> Self {
|
||||
match e {
|
||||
GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected(
|
||||
ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
|
||||
ConnectionError::Socket(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
|
||||
),
|
||||
GetActiveTenantError::Other(e) => QueryError::Other(e),
|
||||
}
|
||||
@@ -1107,10 +1071,7 @@ async fn get_active_tenant_with_timeout(
|
||||
tenant_id: TenantId,
|
||||
_ctx: &RequestContext, /* require get a context to support cancellation in the future */
|
||||
) -> Result<Arc<Tenant>, GetActiveTenantError> {
|
||||
let tenant = match mgr::get_tenant(tenant_id, false).await {
|
||||
Ok(tenant) => tenant,
|
||||
Err(e) => return Err(GetActiveTenantError::Other(e.into())),
|
||||
};
|
||||
let tenant = mgr::get_tenant(tenant_id, false).await?;
|
||||
let wait_time = Duration::from_secs(30);
|
||||
match tokio::time::timeout(wait_time, tenant.wait_to_become_active()).await {
|
||||
Ok(Ok(())) => Ok(tenant),
|
||||
|
||||
@@ -12,7 +12,9 @@
|
||||
//!
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use bytes::Bytes;
|
||||
use futures::FutureExt;
|
||||
use futures::Stream;
|
||||
use pageserver_api::models::TimelineState;
|
||||
use remote_storage::DownloadError;
|
||||
use remote_storage::GenericRemoteStorage;
|
||||
@@ -237,13 +239,14 @@ impl UninitializedTimeline<'_> {
|
||||
/// Prepares timeline data by loading it from the basebackup archive.
|
||||
pub async fn import_basebackup_from_tar(
|
||||
self,
|
||||
copyin_read: &mut (impl tokio::io::AsyncRead + Send + Sync + Unpin),
|
||||
copyin_stream: &mut (impl Stream<Item = io::Result<Bytes>> + Sync + Send + Unpin),
|
||||
base_lsn: Lsn,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<Arc<Timeline>> {
|
||||
let raw_timeline = self.raw_timeline()?;
|
||||
|
||||
import_datadir::import_basebackup_from_tar(raw_timeline, copyin_read, base_lsn, ctx)
|
||||
let mut reader = tokio_util::io::StreamReader::new(copyin_stream);
|
||||
import_datadir::import_basebackup_from_tar(raw_timeline, &mut reader, base_lsn, ctx)
|
||||
.await
|
||||
.context("Failed to import basebackup")?;
|
||||
|
||||
@@ -1240,8 +1243,11 @@ impl Tenant {
|
||||
"Cannot run GC iteration on inactive tenant"
|
||||
);
|
||||
|
||||
self.gc_iteration_internal(target_timeline_id, horizon, pitr, ctx)
|
||||
.await
|
||||
let gc_result = self
|
||||
.gc_iteration_internal(target_timeline_id, horizon, pitr, ctx)
|
||||
.await;
|
||||
|
||||
gc_result
|
||||
}
|
||||
|
||||
/// Perform one compaction iteration.
|
||||
@@ -3170,44 +3176,6 @@ mod tests {
|
||||
}
|
||||
*/
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_branchpoints_from_an_inactive_timeline() -> anyhow::Result<()> {
|
||||
let (tenant, ctx) =
|
||||
TenantHarness::create("test_get_branchpoints_from_an_inactive_timeline")?
|
||||
.load()
|
||||
.await;
|
||||
let tline = tenant
|
||||
.create_empty_timeline(TIMELINE_ID, Lsn(0), DEFAULT_PG_VERSION, &ctx)?
|
||||
.initialize(&ctx)?;
|
||||
make_some_layers(tline.as_ref(), Lsn(0x20)).await?;
|
||||
|
||||
tenant
|
||||
.branch_timeline(&tline, NEW_TIMELINE_ID, Some(Lsn(0x40)), &ctx)
|
||||
.await?;
|
||||
let newtline = tenant
|
||||
.get_timeline(NEW_TIMELINE_ID, true)
|
||||
.expect("Should have a local timeline");
|
||||
|
||||
make_some_layers(newtline.as_ref(), Lsn(0x60)).await?;
|
||||
|
||||
tline.set_state(TimelineState::Broken);
|
||||
|
||||
tenant
|
||||
.gc_iteration(Some(TIMELINE_ID), 0x10, Duration::ZERO, &ctx)
|
||||
.await?;
|
||||
|
||||
assert_eq!(
|
||||
newtline.get(*TEST_KEY, Lsn(0x50), &ctx).await?,
|
||||
TEST_IMG(&format!("foo at {}", Lsn(0x40)))
|
||||
);
|
||||
|
||||
let branchpoints = &tline.gc_info.read().unwrap().retain_lsns;
|
||||
assert_eq!(branchpoints.len(), 1);
|
||||
assert_eq!(branchpoints[0], Lsn(0x40));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retain_data_in_parent_which_is_needed_for_child() -> anyhow::Result<()> {
|
||||
let (tenant, ctx) =
|
||||
|
||||
@@ -51,6 +51,9 @@ where
|
||||
///
|
||||
/// A "cursor" for efficiently reading multiple pages from a BlockReader
|
||||
///
|
||||
/// A cursor caches the last accessed page, allowing for faster access if the
|
||||
/// same block is accessed repeatedly.
|
||||
///
|
||||
/// You can access the last page with `*cursor`. 'read_blk' returns 'self', so
|
||||
/// that in many cases you can use a BlockCursor as a drop-in replacement for
|
||||
/// the underlying BlockReader. For example:
|
||||
@@ -70,6 +73,8 @@ where
|
||||
R: BlockReader,
|
||||
{
|
||||
reader: R,
|
||||
/// last accessed page
|
||||
cache: Option<(u32, R::BlockLease)>,
|
||||
}
|
||||
|
||||
impl<R> BlockCursor<R>
|
||||
@@ -77,13 +82,40 @@ where
|
||||
R: BlockReader,
|
||||
{
|
||||
pub fn new(reader: R) -> Self {
|
||||
BlockCursor { reader }
|
||||
BlockCursor {
|
||||
reader,
|
||||
cache: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_blk(&mut self, blknum: u32) -> Result<R::BlockLease, std::io::Error> {
|
||||
self.reader.read_blk(blknum)
|
||||
pub fn read_blk(&mut self, blknum: u32) -> Result<&Self, std::io::Error> {
|
||||
// Fast return if this is the same block as before
|
||||
if let Some((cached_blk, _buf)) = &self.cache {
|
||||
if *cached_blk == blknum {
|
||||
return Ok(self);
|
||||
}
|
||||
}
|
||||
|
||||
// Read the block from the underlying reader, and cache it
|
||||
self.cache = None;
|
||||
let buf = self.reader.read_blk(blknum)?;
|
||||
self.cache = Some((blknum, buf));
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> Deref for BlockCursor<R>
|
||||
where
|
||||
R: BlockReader,
|
||||
{
|
||||
type Target = [u8; PAGE_SZ];
|
||||
|
||||
fn deref(&self) -> &<Self as Deref>::Target {
|
||||
&self.cache.as_ref().unwrap().1
|
||||
}
|
||||
}
|
||||
|
||||
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
|
||||
|
||||
/// An adapter for reading a (virtual) file using the page cache.
|
||||
|
||||
@@ -222,6 +222,48 @@ impl TenantConfOpt {
|
||||
eviction_policy: self.eviction_policy.unwrap_or(global_conf.eviction_policy),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(&mut self, other: &TenantConfOpt) {
|
||||
if let Some(checkpoint_distance) = other.checkpoint_distance {
|
||||
self.checkpoint_distance = Some(checkpoint_distance);
|
||||
}
|
||||
if let Some(checkpoint_timeout) = other.checkpoint_timeout {
|
||||
self.checkpoint_timeout = Some(checkpoint_timeout);
|
||||
}
|
||||
if let Some(compaction_target_size) = other.compaction_target_size {
|
||||
self.compaction_target_size = Some(compaction_target_size);
|
||||
}
|
||||
if let Some(compaction_period) = other.compaction_period {
|
||||
self.compaction_period = Some(compaction_period);
|
||||
}
|
||||
if let Some(compaction_threshold) = other.compaction_threshold {
|
||||
self.compaction_threshold = Some(compaction_threshold);
|
||||
}
|
||||
if let Some(gc_horizon) = other.gc_horizon {
|
||||
self.gc_horizon = Some(gc_horizon);
|
||||
}
|
||||
if let Some(gc_period) = other.gc_period {
|
||||
self.gc_period = Some(gc_period);
|
||||
}
|
||||
if let Some(image_creation_threshold) = other.image_creation_threshold {
|
||||
self.image_creation_threshold = Some(image_creation_threshold);
|
||||
}
|
||||
if let Some(pitr_interval) = other.pitr_interval {
|
||||
self.pitr_interval = Some(pitr_interval);
|
||||
}
|
||||
if let Some(walreceiver_connect_timeout) = other.walreceiver_connect_timeout {
|
||||
self.walreceiver_connect_timeout = Some(walreceiver_connect_timeout);
|
||||
}
|
||||
if let Some(lagging_wal_timeout) = other.lagging_wal_timeout {
|
||||
self.lagging_wal_timeout = Some(lagging_wal_timeout);
|
||||
}
|
||||
if let Some(max_lsn_wal_lag) = other.max_lsn_wal_lag {
|
||||
self.max_lsn_wal_lag = Some(max_lsn_wal_lag);
|
||||
}
|
||||
if let Some(trace_read_requests) = other.trace_read_requests {
|
||||
self.trace_read_requests = Some(trace_read_requests);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TenantConf {
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
//! used to keep in-memory layers spilled on disk.
|
||||
|
||||
use crate::config::PageServerConf;
|
||||
use crate::page_cache::{self, ReadBufResult, WriteBufResult, PAGE_SZ};
|
||||
use crate::page_cache;
|
||||
use crate::page_cache::PAGE_SZ;
|
||||
use crate::page_cache::{ReadBufResult, WriteBufResult};
|
||||
use crate::tenant::blob_io::BlobWriter;
|
||||
use crate::tenant::block_io::BlockReader;
|
||||
use crate::virtual_file::VirtualFile;
|
||||
@@ -425,6 +427,7 @@ mod tests {
|
||||
let actual = cursor.read_blob(pos)?;
|
||||
assert_eq!(actual, expected);
|
||||
}
|
||||
drop(cursor);
|
||||
|
||||
// Test a large blob that spans multiple pages
|
||||
let mut large_data = Vec::new();
|
||||
|
||||
@@ -154,7 +154,11 @@ where
|
||||
expected: &Arc<L>,
|
||||
new: Arc<L>,
|
||||
) -> anyhow::Result<Replacement<Arc<L>>> {
|
||||
fail::fail_point!("layermap-replace-notfound", |_| Ok(Replacement::NotFound));
|
||||
fail::fail_point!("layermap-replace-notfound", |_| Ok(
|
||||
// this is not what happens if an L0 layer was not found a anyhow error but perhaps
|
||||
// that should be changed. this is good enough to show a replacement failure.
|
||||
Replacement::NotFound
|
||||
));
|
||||
|
||||
self.layer_map.replace_historic_noflush(expected, new)
|
||||
}
|
||||
@@ -336,15 +340,12 @@ where
|
||||
|
||||
let l0_index = if expected_l0 {
|
||||
// find the index in case replace worked, we need to replace that as well
|
||||
let pos = self
|
||||
.l0_delta_layers
|
||||
.iter()
|
||||
.position(|slot| Self::compare_arced_layers(slot, expected));
|
||||
|
||||
if pos.is_none() {
|
||||
return Ok(Replacement::NotFound);
|
||||
}
|
||||
pos
|
||||
Some(
|
||||
self.l0_delta_layers
|
||||
.iter()
|
||||
.position(|slot| Self::compare_arced_layers(slot, expected))
|
||||
.ok_or_else(|| anyhow::anyhow!("existing l0 delta layer was not found"))?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -803,26 +804,6 @@ mod tests {
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn replacing_missing_l0_is_notfound() {
|
||||
// original impl had an oversight, and L0 was an anyhow::Error. anyhow::Error should
|
||||
// however only happen for precondition failures.
|
||||
|
||||
let layer = "000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000053423C21-0000000053424D69";
|
||||
let layer = LayerFileName::from_str(layer).unwrap();
|
||||
let layer = LayerDescriptor::from(layer);
|
||||
|
||||
// same skeletan construction; see scenario below
|
||||
let not_found: Arc<dyn Layer> = Arc::new(layer.clone());
|
||||
let new_version: Arc<dyn Layer> = Arc::new(layer);
|
||||
|
||||
let mut map = LayerMap::default();
|
||||
|
||||
let res = map.batch_update().replace_historic(¬_found, new_version);
|
||||
|
||||
assert!(matches!(res, Ok(Replacement::NotFound)), "{res:?}");
|
||||
}
|
||||
|
||||
fn l0_delta_layers_updated_scenario(layer_name: &str, expected_l0: bool) {
|
||||
let name = LayerFileName::from_str(layer_name).unwrap();
|
||||
let skeleton = LayerDescriptor::from(name);
|
||||
@@ -832,8 +813,7 @@ mod tests {
|
||||
|
||||
let mut map = LayerMap::default();
|
||||
|
||||
// two disjoint Arcs in different lifecycle phases. even if it seems they must be the
|
||||
// same layer, we use LayerMap::compare_arced_layers as the identity of layers.
|
||||
// two disjoint Arcs in different lifecycle phases.
|
||||
assert!(!LayerMap::compare_arced_layers(&remote, &downloaded));
|
||||
|
||||
let expected_in_counts = (1, usize::from(expected_l0));
|
||||
|
||||
@@ -289,7 +289,7 @@ pub async fn set_new_tenant_config(
|
||||
conf: &'static PageServerConf,
|
||||
new_tenant_conf: TenantConfOpt,
|
||||
tenant_id: TenantId,
|
||||
) -> Result<(), TenantStateError> {
|
||||
) -> anyhow::Result<()> {
|
||||
info!("configuring tenant {tenant_id}");
|
||||
let tenant = get_tenant(tenant_id, true).await?;
|
||||
|
||||
@@ -306,20 +306,16 @@ pub async fn set_new_tenant_config(
|
||||
|
||||
/// Gets the tenant from the in-memory data, erroring if it's absent or is not fitting to the query.
|
||||
/// `active_only = true` allows to query only tenants that are ready for operations, erroring on other kinds of tenants.
|
||||
pub async fn get_tenant(
|
||||
tenant_id: TenantId,
|
||||
active_only: bool,
|
||||
) -> Result<Arc<Tenant>, TenantStateError> {
|
||||
pub async fn get_tenant(tenant_id: TenantId, active_only: bool) -> anyhow::Result<Arc<Tenant>> {
|
||||
let m = TENANTS.read().await;
|
||||
let tenant = m
|
||||
.get(&tenant_id)
|
||||
.ok_or(TenantStateError::NotFound(tenant_id))?;
|
||||
.with_context(|| format!("Tenant {tenant_id} not found in the local state"))?;
|
||||
if active_only && !tenant.is_active() {
|
||||
tracing::warn!(
|
||||
anyhow::bail!(
|
||||
"Tenant {tenant_id} is not active. Current state: {:?}",
|
||||
tenant.current_state()
|
||||
);
|
||||
Err(TenantStateError::NotActive(tenant_id))
|
||||
)
|
||||
} else {
|
||||
Ok(Arc::clone(tenant))
|
||||
}
|
||||
@@ -329,28 +325,21 @@ pub async fn delete_timeline(
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<(), TenantStateError> {
|
||||
let tenant = get_tenant(tenant_id, true).await?;
|
||||
tenant.delete_timeline(timeline_id, ctx).await?;
|
||||
Ok(())
|
||||
}
|
||||
) -> anyhow::Result<()> {
|
||||
match get_tenant(tenant_id, true).await {
|
||||
Ok(tenant) => {
|
||||
tenant.delete_timeline(timeline_id, ctx).await?;
|
||||
}
|
||||
Err(e) => anyhow::bail!("Cannot access tenant {tenant_id} in local tenant state: {e:?}"),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum TenantStateError {
|
||||
#[error("Tenant {0} not found")]
|
||||
NotFound(TenantId),
|
||||
#[error("Tenant {0} is stopping")]
|
||||
IsStopping(TenantId),
|
||||
#[error("Tenant {0} is not active")]
|
||||
NotActive(TenantId),
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn detach_tenant(
|
||||
conf: &'static PageServerConf,
|
||||
tenant_id: TenantId,
|
||||
) -> Result<(), TenantStateError> {
|
||||
) -> anyhow::Result<()> {
|
||||
remove_tenant_from_memory(tenant_id, async {
|
||||
let local_tenant_directory = conf.tenant_path(&tenant_id);
|
||||
fs::remove_dir_all(&local_tenant_directory)
|
||||
@@ -390,7 +379,7 @@ pub async fn load_tenant(
|
||||
pub async fn ignore_tenant(
|
||||
conf: &'static PageServerConf,
|
||||
tenant_id: TenantId,
|
||||
) -> Result<(), TenantStateError> {
|
||||
) -> anyhow::Result<()> {
|
||||
remove_tenant_from_memory(tenant_id, async {
|
||||
let ignore_mark_file = conf.tenant_ignore_mark_file_path(tenant_id);
|
||||
fs::File::create(&ignore_mark_file)
|
||||
@@ -500,7 +489,7 @@ where
|
||||
async fn remove_tenant_from_memory<V, F>(
|
||||
tenant_id: TenantId,
|
||||
tenant_cleanup: F,
|
||||
) -> Result<V, TenantStateError>
|
||||
) -> anyhow::Result<V>
|
||||
where
|
||||
F: std::future::Future<Output = anyhow::Result<V>>,
|
||||
{
|
||||
@@ -516,9 +505,11 @@ where
|
||||
| TenantState::Loading
|
||||
| TenantState::Broken
|
||||
| TenantState::Active => tenant.set_stopping(),
|
||||
TenantState::Stopping => return Err(TenantStateError::IsStopping(tenant_id)),
|
||||
TenantState::Stopping => {
|
||||
anyhow::bail!("Tenant {tenant_id} is stopping already")
|
||||
}
|
||||
},
|
||||
None => return Err(TenantStateError::NotFound(tenant_id)),
|
||||
None => anyhow::bail!("Tenant not found for id {tenant_id}"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -541,15 +532,10 @@ where
|
||||
Err(e) => {
|
||||
let tenants_accessor = TENANTS.read().await;
|
||||
match tenants_accessor.get(&tenant_id) {
|
||||
Some(tenant) => {
|
||||
tenant.set_broken(&e.to_string());
|
||||
}
|
||||
None => {
|
||||
warn!("Tenant {tenant_id} got removed from memory");
|
||||
return Err(TenantStateError::NotFound(tenant_id));
|
||||
}
|
||||
Some(tenant) => tenant.set_broken(&e.to_string()),
|
||||
None => warn!("Tenant {tenant_id} got removed from memory"),
|
||||
}
|
||||
Err(TenantStateError::Other(e))
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -569,7 +555,7 @@ pub async fn immediate_gc(
|
||||
let tenant = guard
|
||||
.get(&tenant_id)
|
||||
.map(Arc::clone)
|
||||
.with_context(|| format!("tenant {tenant_id}"))
|
||||
.with_context(|| format!("Tenant {tenant_id} not found"))
|
||||
.map_err(ApiError::NotFound)?;
|
||||
|
||||
let gc_horizon = gc_req.gc_horizon.unwrap_or_else(|| tenant.get_gc_horizon());
|
||||
@@ -619,7 +605,7 @@ pub async fn immediate_compact(
|
||||
let tenant = guard
|
||||
.get(&tenant_id)
|
||||
.map(Arc::clone)
|
||||
.with_context(|| format!("tenant {tenant_id}"))
|
||||
.with_context(|| format!("Tenant {tenant_id} not found"))
|
||||
.map_err(ApiError::NotFound)?;
|
||||
|
||||
let timeline = tenant
|
||||
|
||||
@@ -218,10 +218,9 @@ use tracing::{debug, info, warn};
|
||||
use tracing::{info_span, Instrument};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use crate::metrics::{
|
||||
MeasureRemoteOp, RemoteOpFileKind, RemoteOpKind, RemoteTimelineClientMetrics,
|
||||
REMOTE_ONDEMAND_DOWNLOADED_BYTES, REMOTE_ONDEMAND_DOWNLOADED_LAYERS,
|
||||
};
|
||||
use crate::metrics::RemoteOpFileKind;
|
||||
use crate::metrics::RemoteOpKind;
|
||||
use crate::metrics::{MeasureRemoteOp, RemoteTimelineClientMetrics};
|
||||
use crate::tenant::remote_timeline_client::index::LayerFileMetadata;
|
||||
use crate::{
|
||||
config::PageServerConf,
|
||||
@@ -447,10 +446,6 @@ impl RemoteTimelineClient {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
REMOTE_ONDEMAND_DOWNLOADED_LAYERS.inc();
|
||||
REMOTE_ONDEMAND_DOWNLOADED_BYTES.inc_by(downloaded_size);
|
||||
|
||||
Ok(downloaded_size)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,13 +6,11 @@
|
||||
use std::collections::HashSet;
|
||||
use std::future::Future;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use tokio::fs;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use tracing::{info, warn};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::config::PageServerConf;
|
||||
use crate::tenant::storage_layer::LayerFileName;
|
||||
@@ -28,8 +26,6 @@ async fn fsync_path(path: impl AsRef<std::path::Path>) -> Result<(), std::io::Er
|
||||
fs::File::open(path).await?.sync_all().await
|
||||
}
|
||||
|
||||
static MAX_DOWNLOAD_DURATION: Duration = Duration::from_secs(120);
|
||||
|
||||
///
|
||||
/// If 'metadata' is given, we will validate that the downloaded file's size matches that
|
||||
/// in the metadata. (In the future, we might do more cross-checks, like CRC validation)
|
||||
@@ -68,28 +64,22 @@ pub async fn download_layer_file<'a>(
|
||||
// TODO: this doesn't use the cached fd for some reason?
|
||||
let mut destination_file = fs::File::create(&temp_file_path).await.with_context(|| {
|
||||
format!(
|
||||
"create a destination file for layer '{}'",
|
||||
"Failed to create a destination file for layer '{}'",
|
||||
temp_file_path.display()
|
||||
)
|
||||
})
|
||||
.map_err(DownloadError::Other)?;
|
||||
let mut download = storage.download(&remote_path).await.with_context(|| {
|
||||
format!(
|
||||
"open a download stream for layer with remote storage path '{remote_path:?}'"
|
||||
"Failed to open a download stream for layer with remote storage path '{remote_path:?}'"
|
||||
)
|
||||
})
|
||||
.map_err(DownloadError::Other)?;
|
||||
|
||||
let bytes_amount = tokio::time::timeout(MAX_DOWNLOAD_DURATION, tokio::io::copy(&mut download.download_stream, &mut destination_file))
|
||||
.await
|
||||
.map_err(|e| DownloadError::Other(anyhow::anyhow!("Timed out {:?}", e)))?
|
||||
.with_context(|| {
|
||||
format!("Failed to download layer with remote storage path '{remote_path:?}' into file {temp_file_path:?}")
|
||||
})
|
||||
.map_err(DownloadError::Other)?;
|
||||
|
||||
let bytes_amount = tokio::io::copy(&mut download.download_stream, &mut destination_file).await.with_context(|| {
|
||||
format!("Failed to download layer with remote storage path '{remote_path:?}' into file {temp_file_path:?}")
|
||||
})
|
||||
.map_err(DownloadError::Other)?;
|
||||
Ok((destination_file, bytes_amount))
|
||||
|
||||
},
|
||||
&format!("download {remote_path:?}"),
|
||||
).await?;
|
||||
@@ -310,7 +300,7 @@ where
|
||||
}
|
||||
Err(DownloadError::Other(ref err)) => {
|
||||
// Operation failed FAILED_DOWNLOAD_RETRIES times. Time to give up.
|
||||
warn!("{description} still failed after {attempts} retries, giving up: {err:?}");
|
||||
error!("{description} still failed after {attempts} retries, giving up: {err:?}");
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1246,11 +1246,6 @@ impl Timeline {
|
||||
|
||||
state,
|
||||
};
|
||||
info!(
|
||||
"initialized lsrlsn to {}/{}",
|
||||
disk_consistent_lsn,
|
||||
metadata.prev_record_lsn().unwrap_or(Lsn(0))
|
||||
);
|
||||
result.repartition_threshold = result.get_checkpoint_distance() / 10;
|
||||
result
|
||||
.metrics
|
||||
@@ -1340,7 +1335,6 @@ impl Timeline {
|
||||
lagging_wal_timeout,
|
||||
max_lsn_wal_lag,
|
||||
crate::config::SAFEKEEPER_AUTH_TOKEN.get().cloned(),
|
||||
self.conf.availability_zone.clone(),
|
||||
background_ctx,
|
||||
);
|
||||
}
|
||||
@@ -2721,22 +2715,10 @@ impl Timeline {
|
||||
) -> Result<HashMap<LayerFileName, LayerFileMetadata>, PageReconstructError> {
|
||||
let timer = self.metrics.create_images_time_histo.start_timer();
|
||||
let mut image_layers: Vec<ImageLayer> = Vec::new();
|
||||
|
||||
// We need to avoid holes between generated image layers.
|
||||
// Otherwise LayerMap::image_layer_exists will return false if key range of some layer is covered by more than one
|
||||
// image layer with hole between them. In this case such layer can not be utilized by GC.
|
||||
//
|
||||
// How such hole between partitions can appear?
|
||||
// if we have relation with relid=1 and size 100 and relation with relid=2 with size 200 then result of
|
||||
// KeySpace::partition may contain partitions <100000000..100000099> and <200000000..200000199>.
|
||||
// If there is delta layer <100000000..300000000> then it never be garbage collected because
|
||||
// image layers <100000000..100000099> and <200000000..200000199> are not completely covering it.
|
||||
let mut start = Key::MIN;
|
||||
|
||||
for partition in partitioning.parts.iter() {
|
||||
let img_range = start..partition.ranges.last().unwrap().end;
|
||||
start = img_range.end;
|
||||
if force || self.time_for_new_image_layer(partition, lsn)? {
|
||||
let img_range =
|
||||
partition.ranges.first().unwrap().start..partition.ranges.last().unwrap().end;
|
||||
let mut image_layer_writer = ImageLayerWriter::new(
|
||||
self.conf,
|
||||
self.timeline_id,
|
||||
@@ -2750,6 +2732,7 @@ impl Timeline {
|
||||
"failpoint image-layer-writer-fail-before-finish"
|
||||
)))
|
||||
});
|
||||
|
||||
for range in &partition.ranges {
|
||||
let mut key = range.start;
|
||||
while key < range.end {
|
||||
@@ -3164,7 +3147,9 @@ impl Timeline {
|
||||
}
|
||||
|
||||
fail_point!("delta-layer-writer-fail-before-finish", |_| {
|
||||
Err(anyhow::anyhow!("failpoint delta-layer-writer-fail-before-finish").into())
|
||||
return Err(
|
||||
anyhow::anyhow!("failpoint delta-layer-writer-fail-before-finish").into(),
|
||||
);
|
||||
});
|
||||
|
||||
writer.as_mut().unwrap().put_value(key, lsn, value)?;
|
||||
@@ -3834,7 +3819,7 @@ impl Timeline {
|
||||
remote_layer.ongoing_download.close();
|
||||
} else {
|
||||
// Keep semaphore open. We'll drop the permit at the end of the function.
|
||||
error!("on-demand download failed: {:?}", result.as_ref().unwrap_err());
|
||||
info!("on-demand download failed: {:?}", result.as_ref().unwrap_err());
|
||||
}
|
||||
|
||||
// Don't treat it as an error if the task that triggered the download
|
||||
|
||||
@@ -45,7 +45,6 @@ pub fn spawn_connection_manager_task(
|
||||
lagging_wal_timeout: Duration,
|
||||
max_lsn_wal_lag: NonZeroU64,
|
||||
auth_token: Option<Arc<String>>,
|
||||
availability_zone: Option<String>,
|
||||
ctx: RequestContext,
|
||||
) {
|
||||
let mut broker_client = get_broker_client().clone();
|
||||
@@ -68,7 +67,6 @@ pub fn spawn_connection_manager_task(
|
||||
lagging_wal_timeout,
|
||||
max_lsn_wal_lag,
|
||||
auth_token,
|
||||
availability_zone,
|
||||
);
|
||||
loop {
|
||||
select! {
|
||||
@@ -336,7 +334,6 @@ struct WalreceiverState {
|
||||
/// Data about all timelines, available for connection, fetched from storage broker, grouped by their corresponding safekeeper node id.
|
||||
wal_stream_candidates: HashMap<NodeId, BrokerSkTimeline>,
|
||||
auth_token: Option<Arc<String>>,
|
||||
availability_zone: Option<String>,
|
||||
}
|
||||
|
||||
/// Current connection data.
|
||||
@@ -384,7 +381,6 @@ impl WalreceiverState {
|
||||
lagging_wal_timeout: Duration,
|
||||
max_lsn_wal_lag: NonZeroU64,
|
||||
auth_token: Option<Arc<String>>,
|
||||
availability_zone: Option<String>,
|
||||
) -> Self {
|
||||
let id = TenantTimelineId {
|
||||
tenant_id: timeline.tenant_id,
|
||||
@@ -400,7 +396,6 @@ impl WalreceiverState {
|
||||
wal_stream_candidates: HashMap::new(),
|
||||
wal_connection_retries: HashMap::new(),
|
||||
auth_token,
|
||||
availability_zone,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -745,7 +740,6 @@ impl WalreceiverState {
|
||||
None => None,
|
||||
Some(x) => Some(x),
|
||||
},
|
||||
self.availability_zone.as_deref(),
|
||||
) {
|
||||
Ok(connstr) => Some((*sk_id, info, connstr)),
|
||||
Err(e) => {
|
||||
@@ -830,24 +824,17 @@ fn wal_stream_connection_config(
|
||||
}: TenantTimelineId,
|
||||
listen_pg_addr_str: &str,
|
||||
auth_token: Option<&str>,
|
||||
availability_zone: Option<&str>,
|
||||
) -> anyhow::Result<PgConnectionConfig> {
|
||||
let (host, port) =
|
||||
parse_host_port(listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?;
|
||||
let port = port.unwrap_or(5432);
|
||||
let mut connstr = PgConnectionConfig::new_host_port(host, port)
|
||||
Ok(PgConnectionConfig::new_host_port(host, port)
|
||||
.extend_options([
|
||||
"-c".to_owned(),
|
||||
format!("timeline_id={}", timeline_id),
|
||||
format!("tenant_id={}", tenant_id),
|
||||
])
|
||||
.set_password(auth_token.map(|s| s.to_owned()));
|
||||
|
||||
if let Some(availability_zone) = availability_zone {
|
||||
connstr = connstr.extend_options([format!("availability_zone={}", availability_zone)]);
|
||||
}
|
||||
|
||||
Ok(connstr)
|
||||
.set_password(auth_token.map(|s| s.to_owned())))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -1286,7 +1273,6 @@ mod tests {
|
||||
wal_stream_candidates: HashMap::new(),
|
||||
wal_connection_retries: HashMap::new(),
|
||||
auth_token: None,
|
||||
availability_zone: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,11 +33,10 @@ use crate::{
|
||||
walingest::WalIngest,
|
||||
walrecord::DecodedWALRecord,
|
||||
};
|
||||
use postgres_backend::is_expected_io_error;
|
||||
use postgres_connection::PgConnectionConfig;
|
||||
use postgres_ffi::waldecoder::WalStreamDecoder;
|
||||
use pq_proto::ReplicationFeedback;
|
||||
use utils::lsn::Lsn;
|
||||
use utils::{lsn::Lsn, postgres_backend_async::is_expected_io_error};
|
||||
|
||||
/// Status of the connection.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
@@ -354,7 +353,7 @@ pub async fn handle_walreceiver_connection(
|
||||
debug!("neon_status_update {status_update:?}");
|
||||
|
||||
let mut data = BytesMut::new();
|
||||
status_update.serialize(&mut data);
|
||||
status_update.serialize(&mut data)?;
|
||||
physical_stream
|
||||
.as_mut()
|
||||
.zenith_status_update(data.len() as u64, &data)
|
||||
@@ -435,8 +434,8 @@ fn ignore_expected_errors(pg_error: postgres::Error) -> anyhow::Result<postgres:
|
||||
{
|
||||
return Ok(pg_error);
|
||||
} else if let Some(db_error) = pg_error.as_db_error() {
|
||||
if db_error.code() == &SqlState::SUCCESSFUL_COMPLETION
|
||||
&& db_error.message().contains("ending streaming")
|
||||
if db_error.code() == &SqlState::CONNECTION_FAILURE
|
||||
&& db_error.message().contains("end streaming")
|
||||
{
|
||||
return Ok(pg_error);
|
||||
}
|
||||
|
||||
@@ -23,11 +23,13 @@ use bytes::{BufMut, Bytes, BytesMut};
|
||||
use nix::poll::*;
|
||||
use serde::Serialize;
|
||||
use std::collections::VecDeque;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::prelude::*;
|
||||
use std::io::{Error, ErrorKind};
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::os::unix::io::{AsRawFd, RawFd};
|
||||
use std::os::unix::prelude::CommandExt;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
|
||||
use std::sync::{Mutex, MutexGuard};
|
||||
@@ -254,53 +256,52 @@ impl PostgresRedoManager {
|
||||
pg_version: u32,
|
||||
) -> Result<Bytes, WalRedoError> {
|
||||
let (rel, blknum) = key_to_rel_block(key).or(Err(WalRedoError::InvalidRecord))?;
|
||||
const MAX_RETRY_ATTEMPTS: u32 = 1;
|
||||
|
||||
let start_time = Instant::now();
|
||||
let mut n_attempts = 0u32;
|
||||
loop {
|
||||
let mut proc = self.stdin.lock().unwrap();
|
||||
let lock_time = Instant::now();
|
||||
|
||||
// launch the WAL redo process on first use
|
||||
if proc.is_none() {
|
||||
self.launch(&mut proc, pg_version)?;
|
||||
}
|
||||
WAL_REDO_WAIT_TIME.observe(lock_time.duration_since(start_time).as_secs_f64());
|
||||
let mut proc = self.stdin.lock().unwrap();
|
||||
let lock_time = Instant::now();
|
||||
|
||||
// Relational WAL records are applied using wal-redo-postgres
|
||||
let buf_tag = BufferTag { rel, blknum };
|
||||
let result = self
|
||||
.apply_wal_records(proc, buf_tag, &base_img, records, wal_redo_timeout)
|
||||
.map_err(WalRedoError::IoError);
|
||||
// launch the WAL redo process on first use
|
||||
if proc.is_none() {
|
||||
self.launch(&mut proc, pg_version)?;
|
||||
}
|
||||
WAL_REDO_WAIT_TIME.observe(lock_time.duration_since(start_time).as_secs_f64());
|
||||
|
||||
let end_time = Instant::now();
|
||||
let duration = end_time.duration_since(lock_time);
|
||||
// Relational WAL records are applied using wal-redo-postgres
|
||||
let buf_tag = BufferTag { rel, blknum };
|
||||
let result = self
|
||||
.apply_wal_records(proc, buf_tag, base_img, records, wal_redo_timeout)
|
||||
.map_err(WalRedoError::IoError);
|
||||
|
||||
let len = records.len();
|
||||
let nbytes = records.iter().fold(0, |acumulator, record| {
|
||||
acumulator
|
||||
+ match &record.1 {
|
||||
NeonWalRecord::Postgres { rec, .. } => rec.len(),
|
||||
_ => unreachable!("Only PostgreSQL records are accepted in this batch"),
|
||||
}
|
||||
});
|
||||
let end_time = Instant::now();
|
||||
let duration = end_time.duration_since(lock_time);
|
||||
|
||||
WAL_REDO_TIME.observe(duration.as_secs_f64());
|
||||
WAL_REDO_RECORDS_HISTOGRAM.observe(len as f64);
|
||||
WAL_REDO_BYTES_HISTOGRAM.observe(nbytes as f64);
|
||||
let len = records.len();
|
||||
let nbytes = records.iter().fold(0, |acumulator, record| {
|
||||
acumulator
|
||||
+ match &record.1 {
|
||||
NeonWalRecord::Postgres { rec, .. } => rec.len(),
|
||||
_ => unreachable!("Only PostgreSQL records are accepted in this batch"),
|
||||
}
|
||||
});
|
||||
|
||||
debug!(
|
||||
"postgres applied {} WAL records ({} bytes) in {} us to reconstruct page image at LSN {}",
|
||||
len,
|
||||
nbytes,
|
||||
duration.as_micros(),
|
||||
lsn
|
||||
);
|
||||
WAL_REDO_TIME.observe(duration.as_secs_f64());
|
||||
WAL_REDO_RECORDS_HISTOGRAM.observe(len as f64);
|
||||
WAL_REDO_BYTES_HISTOGRAM.observe(nbytes as f64);
|
||||
|
||||
// If something went wrong, don't try to reuse the process. Kill it, and
|
||||
// next request will launch a new one.
|
||||
if result.is_err() {
|
||||
error!(
|
||||
debug!(
|
||||
"postgres applied {} WAL records ({} bytes) in {} us to reconstruct page image at LSN {}",
|
||||
len,
|
||||
nbytes,
|
||||
duration.as_micros(),
|
||||
lsn
|
||||
);
|
||||
|
||||
// If something went wrong, don't try to reuse the process. Kill it, and
|
||||
// next request will launch a new one.
|
||||
if result.is_err() {
|
||||
error!(
|
||||
"error applying {} WAL records {}..{} ({} bytes) to base image with LSN {} to reconstruct page image at LSN {}",
|
||||
records.len(),
|
||||
records.first().map(|p| p.0).unwrap_or(Lsn(0)),
|
||||
@@ -309,28 +310,24 @@ impl PostgresRedoManager {
|
||||
base_img_lsn,
|
||||
lsn
|
||||
);
|
||||
// self.stdin only holds stdin & stderr as_raw_fd().
|
||||
// Dropping it as part of take() doesn't close them.
|
||||
// The owning objects (ChildStdout and ChildStderr) are stored in
|
||||
// self.stdout and self.stderr, respsectively.
|
||||
// We intentionally keep them open here to avoid a race between
|
||||
// currently running `apply_wal_records()` and a `launch()` call
|
||||
// after we return here.
|
||||
// The currently running `apply_wal_records()` must not read from
|
||||
// the newly launched process.
|
||||
// By keeping self.stdout and self.stderr open here, `launch()` will
|
||||
// get other file descriptors for the new child's stdout and stderr,
|
||||
// and hence the current `apply_wal_records()` calls will observe
|
||||
// `output.stdout.as_raw_fd() != stdout_fd` .
|
||||
if let Some(proc) = self.stdin.lock().unwrap().take() {
|
||||
proc.child.kill_and_wait();
|
||||
}
|
||||
}
|
||||
n_attempts += 1;
|
||||
if n_attempts > MAX_RETRY_ATTEMPTS || result.is_ok() {
|
||||
return result;
|
||||
// self.stdin only holds stdin & stderr as_raw_fd().
|
||||
// Dropping it as part of take() doesn't close them.
|
||||
// The owning objects (ChildStdout and ChildStderr) are stored in
|
||||
// self.stdout and self.stderr, respsectively.
|
||||
// We intentionally keep them open here to avoid a race between
|
||||
// currently running `apply_wal_records()` and a `launch()` call
|
||||
// after we return here.
|
||||
// The currently running `apply_wal_records()` must not read from
|
||||
// the newly launched process.
|
||||
// By keeping self.stdout and self.stderr open here, `launch()` will
|
||||
// get other file descriptors for the new child's stdout and stderr,
|
||||
// and hence the current `apply_wal_records()` calls will observe
|
||||
// `output.stdout.as_raw_fd() != stdout_fd` .
|
||||
if let Some(proc) = self.stdin.lock().unwrap().take() {
|
||||
proc.child.kill_and_wait();
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
///
|
||||
@@ -637,26 +634,26 @@ impl PostgresRedoManager {
|
||||
input: &mut MutexGuard<Option<ProcessInput>>,
|
||||
pg_version: u32,
|
||||
) -> Result<(), Error> {
|
||||
// Previous versions of wal-redo required data directory and that directories
|
||||
// occupied some space on disk. Remove it if we face it.
|
||||
//
|
||||
// This code could be dropped after one release cycle.
|
||||
let legacy_datadir = path_with_suffix_extension(
|
||||
// FIXME: We need a dummy Postgres cluster to run the process in. Currently, we
|
||||
// just create one with constant name. That fails if you try to launch more than
|
||||
// one WAL redo manager concurrently.
|
||||
let datadir = path_with_suffix_extension(
|
||||
self.conf
|
||||
.tenant_path(&self.tenant_id)
|
||||
.join("wal-redo-datadir"),
|
||||
TEMP_FILE_SUFFIX,
|
||||
);
|
||||
if legacy_datadir.exists() {
|
||||
info!("legacy wal-redo datadir {legacy_datadir:?} exists, removing");
|
||||
fs::remove_dir_all(&legacy_datadir).map_err(|e| {
|
||||
|
||||
// Create empty data directory for wal-redo postgres, deleting old one first.
|
||||
if datadir.exists() {
|
||||
info!("old temporary datadir {datadir:?} exists, removing");
|
||||
fs::remove_dir_all(&datadir).map_err(|e| {
|
||||
Error::new(
|
||||
e.kind(),
|
||||
format!("legacy wal-redo datadir {legacy_datadir:?} removal failure: {e}"),
|
||||
format!("Old temporary dir {datadir:?} removal failure: {e}"),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
let pg_bin_dir_path = self
|
||||
.conf
|
||||
.pg_bin_dir(pg_version)
|
||||
@@ -666,6 +663,35 @@ impl PostgresRedoManager {
|
||||
.pg_lib_dir(pg_version)
|
||||
.map_err(|e| Error::new(ErrorKind::Other, format!("incorrect pg_lib_dir path: {e}")))?;
|
||||
|
||||
info!("running initdb in {}", datadir.display());
|
||||
let initdb = Command::new(pg_bin_dir_path.join("initdb"))
|
||||
.args(["-D", &datadir.to_string_lossy()])
|
||||
.arg("-N")
|
||||
.env_clear()
|
||||
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
|
||||
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path) // macOS
|
||||
.close_fds()
|
||||
.output()
|
||||
.map_err(|e| Error::new(e.kind(), format!("failed to execute initdb: {e}")))?;
|
||||
|
||||
if !initdb.status.success() {
|
||||
return Err(Error::new(
|
||||
ErrorKind::Other,
|
||||
format!(
|
||||
"initdb failed\nstdout: {}\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&initdb.stdout),
|
||||
String::from_utf8_lossy(&initdb.stderr)
|
||||
),
|
||||
));
|
||||
} else {
|
||||
// Limit shared cache for wal-redo-postgres
|
||||
let mut config = OpenOptions::new()
|
||||
.append(true)
|
||||
.open(PathBuf::from(&datadir).join("postgresql.conf"))?;
|
||||
config.write_all(b"shared_buffers=128kB\n")?;
|
||||
config.write_all(b"fsync=off\n")?;
|
||||
}
|
||||
|
||||
// Start postgres itself
|
||||
let child = Command::new(pg_bin_dir_path.join("postgres"))
|
||||
.arg("--wal-redo")
|
||||
@@ -675,6 +701,7 @@ impl PostgresRedoManager {
|
||||
.env_clear()
|
||||
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
|
||||
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path)
|
||||
.env("PGDATA", &datadir)
|
||||
// The redo process is not trusted, and runs in seccomp mode that
|
||||
// doesn't allow it to open any files. We have to also make sure it
|
||||
// doesn't inherit any file descriptors from the pageserver, that
|
||||
@@ -744,7 +771,7 @@ impl PostgresRedoManager {
|
||||
&self,
|
||||
mut input: MutexGuard<Option<ProcessInput>>,
|
||||
tag: BufferTag,
|
||||
base_img: &Option<Bytes>,
|
||||
base_img: Option<Bytes>,
|
||||
records: &[(Lsn, NeonWalRecord)],
|
||||
wal_redo_timeout: Duration,
|
||||
) -> Result<Bytes, std::io::Error> {
|
||||
@@ -760,7 +787,7 @@ impl PostgresRedoManager {
|
||||
let mut writebuf: Vec<u8> = Vec::with_capacity((BLCKSZ as usize) * 3);
|
||||
build_begin_redo_for_block_msg(tag, &mut writebuf);
|
||||
if let Some(img) = base_img {
|
||||
build_push_page_msg(tag, img, &mut writebuf);
|
||||
build_push_page_msg(tag, &img, &mut writebuf);
|
||||
}
|
||||
for (lsn, rec) in records.iter() {
|
||||
if let NeonWalRecord::Postgres {
|
||||
|
||||
@@ -13,14 +13,13 @@ OBJS = \
|
||||
walproposer.o \
|
||||
walproposer_utils.o
|
||||
|
||||
PG_CPPFLAGS = -I$(libpq_srcdir)
|
||||
SHLIB_LINK_INTERNAL = $(libpq)
|
||||
PG_CPPFLAGS = -I$(libpq_srcdir) -DSIMLIB
|
||||
PG_LIBS = $(libpq)
|
||||
|
||||
EXTENSION = neon
|
||||
DATA = neon--1.0.sql
|
||||
PGFILEDESC = "neon - cloud storage for PostgreSQL"
|
||||
|
||||
|
||||
PG_CONFIG = pg_config
|
||||
PGXS := $(shell $(PG_CONFIG) --pgxs)
|
||||
include $(PGXS)
|
||||
|
||||
@@ -32,9 +32,6 @@
|
||||
|
||||
#define PageStoreTrace DEBUG5
|
||||
|
||||
#define MAX_RECONNECT_ATTEMPTS 5
|
||||
#define RECONNECT_INTERVAL_USEC 1000000
|
||||
|
||||
bool connected = false;
|
||||
PGconn *pageserver_conn = NULL;
|
||||
|
||||
@@ -55,8 +52,8 @@ int readahead_buffer_size = 128;
|
||||
|
||||
static void pageserver_flush(void);
|
||||
|
||||
static bool
|
||||
pageserver_connect(int elevel)
|
||||
static void
|
||||
pageserver_connect()
|
||||
{
|
||||
char *query;
|
||||
int ret;
|
||||
@@ -72,11 +69,10 @@ pageserver_connect(int elevel)
|
||||
PQfinish(pageserver_conn);
|
||||
pageserver_conn = NULL;
|
||||
|
||||
ereport(elevel,
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
|
||||
errmsg(NEON_TAG "could not establish connection to pageserver"),
|
||||
errdetail_internal("%s", msg)));
|
||||
return false;
|
||||
}
|
||||
|
||||
query = psprintf("pagestream %s %s", neon_tenant, neon_timeline);
|
||||
@@ -85,8 +81,7 @@ pageserver_connect(int elevel)
|
||||
{
|
||||
PQfinish(pageserver_conn);
|
||||
pageserver_conn = NULL;
|
||||
neon_log(elevel, "could not send pagestream command to pageserver");
|
||||
return false;
|
||||
neon_log(ERROR, "could not send pagestream command to pageserver");
|
||||
}
|
||||
|
||||
pageserver_conn_wes = CreateWaitEventSet(TopMemoryContext, 3);
|
||||
@@ -118,9 +113,8 @@ pageserver_connect(int elevel)
|
||||
FreeWaitEventSet(pageserver_conn_wes);
|
||||
pageserver_conn_wes = NULL;
|
||||
|
||||
neon_log(elevel, "could not complete handshake with pageserver: %s",
|
||||
neon_log(ERROR, "could not complete handshake with pageserver: %s",
|
||||
msg);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -128,7 +122,6 @@ pageserver_connect(int elevel)
|
||||
neon_log(LOG, "libpagestore: connected to '%s'", page_server_connstring_raw);
|
||||
|
||||
connected = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -156,11 +149,8 @@ retry:
|
||||
if (event.events & WL_SOCKET_READABLE)
|
||||
{
|
||||
if (!PQconsumeInput(pageserver_conn))
|
||||
{
|
||||
neon_log(LOG, "could not get response from pageserver: %s",
|
||||
neon_log(ERROR, "could not get response from pageserver: %s",
|
||||
PQerrorMessage(pageserver_conn));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
goto retry;
|
||||
@@ -200,62 +190,31 @@ static void
|
||||
pageserver_send(NeonRequest * request)
|
||||
{
|
||||
StringInfoData req_buff;
|
||||
int n_reconnect_attempts = 0;
|
||||
|
||||
/* If the connection was lost for some reason, reconnect */
|
||||
if (connected && PQstatus(pageserver_conn) == CONNECTION_BAD)
|
||||
pageserver_disconnect();
|
||||
|
||||
if (!connected)
|
||||
pageserver_connect();
|
||||
|
||||
req_buff = nm_pack_request(request);
|
||||
|
||||
/*
|
||||
* If pageserver is stopped, the connections from compute node are broken.
|
||||
* The compute node doesn't notice that immediately, but it will cause the next request to fail, usually on the next query.
|
||||
* That causes user-visible errors if pageserver is restarted, or the tenant is moved from one pageserver to another.
|
||||
* See https://github.com/neondatabase/neon/issues/1138
|
||||
* So try to reestablish connection in case of failure.
|
||||
* Send request.
|
||||
*
|
||||
* In principle, this could block if the output buffer is full, and we
|
||||
* should use async mode and check for interrupts while waiting. In
|
||||
* practice, our requests are small enough to always fit in the output and
|
||||
* TCP buffer.
|
||||
*/
|
||||
while (true)
|
||||
if (PQputCopyData(pageserver_conn, req_buff.data, req_buff.len) <= 0)
|
||||
{
|
||||
if (!connected)
|
||||
{
|
||||
if (!pageserver_connect(n_reconnect_attempts < MAX_RECONNECT_ATTEMPTS ? LOG : ERROR))
|
||||
{
|
||||
n_reconnect_attempts += 1;
|
||||
pg_usleep(RECONNECT_INTERVAL_USEC);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
char *msg = pchomp(PQerrorMessage(pageserver_conn));
|
||||
|
||||
/*
|
||||
* Send request.
|
||||
*
|
||||
* In principle, this could block if the output buffer is full, and we
|
||||
* should use async mode and check for interrupts while waiting. In
|
||||
* practice, our requests are small enough to always fit in the output and
|
||||
* TCP buffer.
|
||||
*/
|
||||
if (PQputCopyData(pageserver_conn, req_buff.data, req_buff.len) <= 0)
|
||||
{
|
||||
char *msg = pchomp(PQerrorMessage(pageserver_conn));
|
||||
if (n_reconnect_attempts < MAX_RECONNECT_ATTEMPTS)
|
||||
{
|
||||
neon_log(LOG, "failed to send page request (try to reconnect): %s", msg);
|
||||
if (n_reconnect_attempts != 0) /* do not sleep before first reconnect attempt, assuming that pageserver is already restarted */
|
||||
pg_usleep(RECONNECT_INTERVAL_USEC);
|
||||
n_reconnect_attempts += 1;
|
||||
continue;
|
||||
}
|
||||
else
|
||||
{
|
||||
pageserver_disconnect();
|
||||
neon_log(ERROR, "failed to send page request: %s", msg);
|
||||
}
|
||||
}
|
||||
break;
|
||||
pageserver_disconnect();
|
||||
neon_log(ERROR, "failed to send page request: %s", msg);
|
||||
}
|
||||
|
||||
pfree(req_buff.data);
|
||||
|
||||
n_unflushed_requests++;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user