Compare commits

..

50 Commits

Author SHA1 Message Date
Arthur Petukhovsky
6a00ad3aab Validate logs 2023-09-18 17:12:40 +00:00
Arthur Petukhovsky
61e6b24cb2 Fix fd leak 2023-09-18 09:54:19 +00:00
Arthur Petukhovsky
44c7d96ed0 Cleanup resources better 2023-09-16 23:10:53 +00:00
Arthur Petukhovsky
10ad3ae4eb Fix WAL page header 2023-09-16 20:50:10 +00:00
Arthur Petukhovsky
eb2886b401 Add test for 1000 WAL messages 2023-09-16 14:58:26 +00:00
Arthur Petukhovsky
0dc262a84a Fix bug in walproposer voting 2023-08-29 14:11:04 +00:00
Arthur Petukhovsky
d801ba7248 Print network config 2023-08-29 14:10:57 +00:00
Arthur Petukhovsky
1effb586ba Make network random unstable 2023-08-29 13:55:36 +00:00
Arthur Petukhovsky
2fd351fd63 Hide debug logs 2023-08-29 10:05:18 +00:00
Arthur Petukhovsky
13e94bf687 Fix truncateLsn bug 2023-08-29 09:03:22 +00:00
Arthur Petukhovsky
41b9750e81 Run many schedules 2023-08-24 23:42:11 +00:00
Arthur Petukhovsky
f8729f046d Fix excessive logs 2023-08-24 17:25:44 +00:00
Arthur Petukhovsky
420d3bc18f Add simulation schedule 2023-08-24 15:24:38 +00:00
Arthur Petukhovsky
33f7877d1b Show simulation time in logs 2023-08-23 10:10:11 +00:00
Arthur Petukhovsky
7de94c959a Support walproposer recovery 2023-08-22 23:15:46 +00:00
Arthur Petukhovsky
731ed3bb64 Support virtual disk in tests 2023-08-17 13:09:55 +00:00
Arthur Petukhovsky
413ce2cfe8 Crash safekeepers 2023-08-17 10:36:23 +00:00
Arthur Petukhovsky
7f36028fab Generate WAL in tests 2023-08-03 16:58:41 +00:00
Arthur Petukhovsky
cb6a8d3fe3 Fix some warnings 2023-07-28 21:37:16 +00:00
Arthur Petukhovsky
095747afc0 Fix walproposer main loop 2023-07-28 21:18:08 +00:00
Arthur Petukhovsky
89bd7ab8a3 Fix read/write in walproposer 2023-07-28 15:14:24 +00:00
Arthur Petukhovsky
5034a8cca0 WIP 2023-07-26 22:51:19 +02:00
Arthur Petukhovsky
55e40d090e Run sync several times 2023-07-25 11:16:47 +00:00
Arthur Petukhovsky
d87e822169 Return LSN from sync safekeepers 2023-07-24 21:15:35 +00:00
Arthur Petukhovsky
296a0cbac2 Add -DSIMLIB 2023-07-21 15:40:47 +00:00
Arthur Petukhovsky
aed14f52d5 Test sync safekeepers 2023-06-03 19:11:28 +00:00
Arthur Petukhovsky
909d7fadb8 Implement simlib sk server 2023-06-02 14:49:55 +00:00
Arthur Petukhovsky
3840d6b18b Clean up C API 2023-06-01 09:38:07 +00:00
Arthur Petukhovsky
65f92232e6 Compile walproposer 2023-05-31 21:06:47 +00:00
Arthur Petukhovsky
0d4f987fc8 Implement full simlib C API 2023-05-31 20:25:25 +00:00
Arthur Petukhovsky
aa0763d49d Run simulator on C code 2023-05-31 16:55:16 +00:00
Arthur Petukhovsky
7b5123edda Fix elog 2023-05-31 15:06:26 +00:00
Arthur Petukhovsky
b6a80bc269 Link postgres to rust statically 2023-05-31 13:19:41 +00:00
Arthur Petukhovsky
ac82b34c64 Create more involved example 2023-05-30 16:43:33 +00:00
Arthur Petukhovsky
a77fc2c5ff Test Rust -> C -> Rust codepath 2023-05-30 16:38:32 +00:00
Arthur Petukhovsky
9ccbec0e14 Spend some time 2023-05-26 14:45:25 +03:00
Arthur Petukhovsky
b55005d2c4 Build simple C func example 2023-05-26 14:44:48 +03:00
Arthur Petukhovsky
6436432a77 Showcase network failures 2023-05-25 12:53:20 +03:00
Arthur Petukhovsky
1b8918e665 Add accept, close and delays to the network 2023-05-25 12:26:57 +03:00
Arthur Petukhovsky
87c9edac7c Add basic support for network delays 2023-05-24 20:28:53 +03:00
Arthur Petukhovsky
5e0550a620 Add os.sleep and os.random 2023-05-24 15:51:30 +03:00
Arthur Petukhovsky
06f493f525 Extract simlib 2023-05-24 13:06:42 +03:00
Arthur Petukhovsky
f6b540ebfe Add initial support for virtual time 2023-05-22 15:00:56 +03:00
Arthur Petukhovsky
83f87af02b Remove sync debug 2023-03-10 00:10:09 +02:00
Arthur Petukhovsky
79823c38cd It looks deterministic now 2023-03-10 00:03:35 +02:00
Arthur Petukhovsky
072fb3d7e9 WIP 2023-03-09 14:59:03 +02:00
Arthur Petukhovsky
f2fb9f6be9 WIP 2023-03-09 14:51:29 +02:00
Arthur Petukhovsky
dd4c8fb568 WIP 2023-03-09 00:51:14 +02:00
Arthur Petukhovsky
9116c01614 WIP 2023-03-08 18:45:13 +02:00
Arthur Petukhovsky
17cd96e022 WIP 2023-03-03 20:33:55 +00:00
210 changed files with 9340 additions and 6160 deletions

View File

@@ -14,3 +14,6 @@ opt-level = 1
[alias]
build_testing = ["build", "--features", "testing"]
[build]
rustflags = ["-C", "default-linker-libraries"]

View File

@@ -19,7 +19,7 @@ inputs:
run_in_parallel:
description: 'Whether to run tests in parallel'
required: false
default: 'false'
default: 'true'
save_perf_report:
description: 'Whether to upload the performance report, if true PERF_TEST_RESULT_CONNSTR env variable should be set'
required: false
@@ -171,7 +171,7 @@ runs:
--junitxml=$TEST_OUTPUT/junit.xml \
--alluredir=$TEST_OUTPUT/allure/results \
--tb=short \
--verbose -k "test_forward or test_create_snapsh" -x \
--verbose \
-rA $TEST_SELECTION $EXTRA_PARAMS
if [[ "${{ inputs.save_perf_report }}" == "true" ]]; then

View File

@@ -91,15 +91,6 @@
tags:
- pageserver
# used in `pageserver.service` template
- name: learn current availability_zone
shell:
cmd: "curl -s http://169.254.169.254/latest/meta-data/placement/availability-zone"
register: ec2_availability_zone
- set_fact:
ec2_availability_zone={{ ec2_availability_zone.stdout }}
- name: upload systemd service definition
ansible.builtin.template:
src: systemd/pageserver.service
@@ -127,7 +118,7 @@
cmd: |
INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers/$INSTANCE_ID | jq '.version = {{ current_version }}' > /tmp/new_version
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/pageservers
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/pageservers
tags:
- pageserver
@@ -162,15 +153,6 @@
tags:
- safekeeper
# used in `safekeeper.service` template
- name: learn current availability_zone
shell:
cmd: "curl -s http://169.254.169.254/latest/meta-data/placement/availability-zone"
register: ec2_availability_zone
- set_fact:
ec2_availability_zone={{ ec2_availability_zone.stdout }}
# in the future safekeepers should discover pageservers byself
# but currently use first pageserver that was discovered
- name: set first pageserver var for safekeepers
@@ -206,6 +188,6 @@
cmd: |
INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers/$INSTANCE_ID | jq '.version = {{ current_version }}' > /tmp/new_version
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/safekeepers
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/safekeepers
tags:
- safekeeper

View File

@@ -27,8 +27,6 @@ storage:
ansible_host: i-0cd8d316ecbb715be
pageserver-1.eu-central-1.aws.neon.tech:
ansible_host: i-090044ed3d383fef0
pageserver-2.eu-central-1.aws.neon.tech:
ansible_host: i-033584edf3f4b6742
safekeepers:
hosts:

View File

@@ -26,7 +26,7 @@ EOF
if ! curl -sf -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers/${INSTANCE_ID} -o /dev/null; then
# not registered, so register it now
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" {{ console_mgmt_base_url }}/management/api/v2/pageservers -d@/tmp/payload | jq -r '.id')
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers -d@/tmp/payload | jq -r '.id')
# init pageserver
sudo -u pageserver /usr/local/bin/pageserver -c "id=${ID}" -c "pg_distrib_dir='/usr/local'" --init -D /storage/pageserver/data

View File

@@ -25,7 +25,7 @@ EOF
if ! curl -sf -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers/${INSTANCE_ID} -o /dev/null; then
# not registered, so register it now
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" {{ console_mgmt_base_url }}/management/api/v2/safekeepers -d@/tmp/payload | jq -r '.id')
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers -d@/tmp/payload | jq -r '.id')
# init safekeeper
sudo -u safekeeper /usr/local/bin/safekeeper --id ${ID} --init -D /storage/safekeeper/data
fi

View File

@@ -6,7 +6,7 @@ After=network.target auditd.service
Type=simple
User=pageserver
Environment=RUST_BACKTRACE=1 NEON_REPO_DIR=/storage/pageserver LD_LIBRARY_PATH=/usr/local/v14/lib SENTRY_DSN={{ SENTRY_URL_PAGESERVER }} SENTRY_ENVIRONMENT={{ sentry_environment }}
ExecStart=/usr/local/bin/pageserver -c "pg_distrib_dir='/usr/local'" -c "listen_pg_addr='0.0.0.0:6400'" -c "listen_http_addr='0.0.0.0:9898'" -c "broker_endpoint='{{ broker_endpoint }}'" -c "availability_zone='{{ ec2_availability_zone }}'" -D /storage/pageserver/data
ExecStart=/usr/local/bin/pageserver -c "pg_distrib_dir='/usr/local'" -c "listen_pg_addr='0.0.0.0:6400'" -c "listen_http_addr='0.0.0.0:9898'" -c "broker_endpoint='{{ broker_endpoint }}'" -D /storage/pageserver/data
ExecReload=/bin/kill -HUP $MAINPID
KillMode=mixed
KillSignal=SIGINT

View File

@@ -6,7 +6,7 @@ After=network.target auditd.service
Type=simple
User=safekeeper
Environment=RUST_BACKTRACE=1 NEON_REPO_DIR=/storage/safekeeper/data LD_LIBRARY_PATH=/usr/local/v14/lib SENTRY_DSN={{ SENTRY_URL_SAFEKEEPER }} SENTRY_ENVIRONMENT={{ sentry_environment }}
ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}{{ hostname_suffix }}:6500 --listen-http {{ inventory_hostname }}{{ hostname_suffix }}:7676 -D /storage/safekeeper/data --broker-endpoint={{ broker_endpoint }} --remote-storage='{bucket_name="{{bucket_name}}", bucket_region="{{bucket_region}}", prefix_in_bucket="{{ safekeeper_s3_prefix }}"}' --availability-zone={{ ec2_availability_zone }}
ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}{{ hostname_suffix }}:6500 --listen-http {{ inventory_hostname }}{{ hostname_suffix }}:7676 -D /storage/safekeeper/data --broker-endpoint={{ broker_endpoint }} --remote-storage='{bucket_name="{{bucket_name}}", bucket_region="{{bucket_region}}", prefix_in_bucket="{{ safekeeper_s3_prefix }}"}'
ExecReload=/bin/kill -HUP $MAINPID
KillMode=mixed
KillSignal=SIGINT

View File

@@ -1,22 +1,6 @@
# Helm chart values for neon-proxy-scram.
# This is a YAML-formatted file.
deploymentStrategy:
type: RollingUpdate
rollingUpdate:
maxSurge: 100%
maxUnavailable: 50%
# Delay the kill signal by 7 days (7 * 24 * 60 * 60)
# The pod(s) will stay in Terminating, keeps the existing connections
# but doesn't receive new ones
containerLifecycle:
preStop:
exec:
command: ["/bin/sh", "-c", "sleep 604800"]
terminationGracePeriodSeconds: 604800
image:
repository: neondatabase/neon

View File

@@ -5,7 +5,6 @@ on:
branches:
- main
- release
- tmp-repro
pull_request:
defaults:
@@ -75,12 +74,15 @@ jobs:
- name: Install Python deps
run: ./scripts/pysync
- name: Run ruff to ensure code format
run: poetry run ruff .
- name: Run isort to ensure code format
run: poetry run isort --diff --check .
- name: Run black to ensure code format
run: poetry run black --diff --check .
- name: Run flake8 to ensure code format
run: poetry run flake8 .
- name: Run mypy to check types
run: poetry run mypy .
@@ -549,48 +551,6 @@ jobs:
- name: Cleanup ECR folder
run: rm -rf ~/.ecr
neon-image-depot:
# For testing this will run side-by-side for a few merges.
# This action is not really optimized yet, but gets the job done
runs-on: [ self-hosted, gen3, small ]
needs: [ tag ]
container: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned
permissions:
contents: read
id-token: write
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 0
- name: Setup go
uses: actions/setup-go@v3
with:
go-version: '1.19'
- name: Set up Depot CLI
uses: depot/setup-action@v1
- name: Install Crane & ECR helper
run: go install github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cli/docker-credential-ecr-login@69c85dc22db6511932bbf119e1a0cc5c90c69a7f # v0.6.0
- name: Configure ECR login
run: |
mkdir /github/home/.docker/
echo "{\"credsStore\":\"ecr-login\"}" > /github/home/.docker/config.json
- name: Build and push
uses: depot/build-push-action@v1
with:
# if no depot.json file is at the root of your repo, you must specify the project id
project: nrdv0s4kcs
push: true
tags: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/neon:depot-${{needs.tag.outputs.build-tag}}
compute-tools-image:
runs-on: [ self-hosted, gen3, large ]
needs: [ tag ]

View File

@@ -31,4 +31,3 @@ jobs:
head: releases/${{ steps.date.outputs.date }}
base: release
title: Release ${{ steps.date.outputs.date }}
team_reviewers: release

2
.gitignore vendored
View File

@@ -18,3 +18,5 @@ test_output/
*.o
*.so
*.Po
tmp

133
Cargo.lock generated
View File

@@ -679,6 +679,25 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cbindgen"
version = "0.24.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b922faaf31122819ec80c4047cc684c6979a087366c069611e33649bf98e18d"
dependencies = [
"clap 3.2.23",
"heck",
"indexmap",
"log",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn",
"tempfile",
"toml",
]
[[package]]
name = "cc"
version = "1.0.79"
@@ -757,9 +776,12 @@ version = "3.2.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71655c45cb9845d3270c9d6df84ebe72b4dad3c2ba3f7023ad47c144e4e473a5"
dependencies = [
"atty",
"bitflags",
"clap_lex 0.2.4",
"indexmap",
"strsim",
"termcolor",
"textwrap",
]
@@ -851,7 +873,6 @@ dependencies = [
"futures",
"hyper",
"notify",
"num_cpus",
"opentelemetry",
"postgres",
"regex",
@@ -914,7 +935,6 @@ dependencies = [
"once_cell",
"pageserver_api",
"postgres",
"postgres_backend",
"postgres_connection",
"regex",
"reqwest",
@@ -1016,6 +1036,20 @@ version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6548a0ad5d2549e111e1f6a11a6c2e2d00ce6a3dafe22948d67c2b443f775e52"
[[package]]
name = "crossbeam"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
dependencies = [
"cfg-if",
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
"crossbeam-queue",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.6"
@@ -1050,6 +1084,16 @@ dependencies = [
"scopeguard",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.14"
@@ -2456,7 +2500,6 @@ dependencies = [
"postgres",
"postgres-protocol",
"postgres-types",
"postgres_backend",
"postgres_connection",
"postgres_ffi",
"pq_proto",
@@ -2679,28 +2722,6 @@ dependencies = [
"postgres-protocol",
]
[[package]]
name = "postgres_backend"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"bytes",
"futures",
"once_cell",
"pq_proto",
"rustls",
"rustls-pemfile",
"serde",
"thiserror",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls",
"tracing",
"workspace_hack",
]
[[package]]
name = "postgres_connection"
version = "0.1.0"
@@ -2748,7 +2769,7 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
name = "pq_proto"
version = "0.1.0"
dependencies = [
"byteorder",
"anyhow",
"bytes",
"pin-project-lite",
"postgres-protocol",
@@ -2923,7 +2944,6 @@ dependencies = [
"opentelemetry",
"parking_lot",
"pin-project-lite",
"postgres_backend",
"pq_proto",
"prometheus",
"rand",
@@ -3303,6 +3323,15 @@ dependencies = [
"base64 0.21.0",
]
[[package]]
name = "rustls-split"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78802c9612b4689d207acff746f38132ca1b12dadb55d471aa5f10fd580f47d3"
dependencies = [
"rustls",
]
[[package]]
name = "rustversion"
version = "1.0.11"
@@ -3328,22 +3357,25 @@ dependencies = [
"clap 4.1.4",
"const_format",
"crc32c",
"crossbeam",
"fs2",
"git-version",
"hex",
"humantime",
"hyper",
"metrics",
"nix",
"once_cell",
"parking_lot",
"postgres",
"postgres-protocol",
"postgres_backend",
"postgres_ffi",
"pq_proto",
"rand",
"regex",
"remote_storage",
"safekeeper_api",
"scopeguard",
"serde",
"serde_json",
"serde_with",
@@ -4524,6 +4556,7 @@ dependencies = [
"bytes",
"criterion",
"futures",
"git-version",
"heapless",
"hex",
"hex-literal",
@@ -4532,9 +4565,12 @@ dependencies = [
"metrics",
"nix",
"once_cell",
"pin-project-lite",
"pq_proto",
"rand",
"routerify",
"rustls",
"rustls-pemfile",
"rustls-split",
"sentry",
"serde",
"serde_json",
@@ -4545,6 +4581,7 @@ dependencies = [
"tempfile",
"thiserror",
"tokio",
"tokio-rustls",
"tracing",
"tracing-subscriber",
"url",
@@ -4600,6 +4637,38 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "walproposer"
version = "0.1.0"
dependencies = [
"anyhow",
"atty",
"bindgen",
"byteorder",
"bytes",
"cbindgen",
"crc32c",
"env_logger",
"hex",
"hyper",
"libc",
"log",
"memoffset 0.8.0",
"once_cell",
"postgres",
"postgres_ffi",
"rand",
"regex",
"safekeeper",
"scopeguard",
"serde",
"thiserror",
"tracing",
"tracing-subscriber",
"utils",
"workspace_hack",
]
[[package]]
name = "want"
version = "0.3.0"
@@ -4848,19 +4917,14 @@ name = "workspace_hack"
version = "0.1.0"
dependencies = [
"anyhow",
"byteorder",
"bytes",
"chrono",
"clap 4.1.4",
"crossbeam-utils",
"digest",
"either",
"fail",
"futures",
"futures-channel",
"futures-core",
"futures-executor",
"futures-sink",
"futures-util",
"hashbrown 0.12.3",
"indexmap",
@@ -4885,7 +4949,6 @@ dependencies = [
"socket2",
"syn",
"tokio",
"tokio-rustls",
"tokio-util",
"tonic",
"tower",

View File

@@ -64,7 +64,6 @@ md5 = "0.7.0"
memoffset = "0.8"
nix = "0.26"
notify = "5.0.0"
num_cpus = "1.15"
num-traits = "0.2.15"
once_cell = "1.13"
opentelemetry = "0.18.0"
@@ -134,16 +133,17 @@ heapless = { default-features=false, features=[], git = "https://github.com/japa
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" }
postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" }
postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" }
pq_proto = { version = "0.1", path = "./libs/pq_proto/" }
remote_storage = { version = "0.1", path = "./libs/remote_storage/" }
safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" }
safekeeper = { path = "./safekeeper/" }
storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy.
tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" }
tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
utils = { version = "0.1", path = "./libs/utils/" }
walproposer = { version = "0.1", path = "./libs/walproposer/" }
## Common library dependency
workspace_hack = { version = "0.1", path = "./workspace_hack/" }

View File

@@ -39,7 +39,7 @@ ARG CACHEPOT_BUCKET=neon-github-dev
COPY --from=pg-build /home/nonroot/pg_install/v14/include/postgresql/server pg_install/v14/include/postgresql/server
COPY --from=pg-build /home/nonroot/pg_install/v15/include/postgresql/server pg_install/v15/include/postgresql/server
COPY --chown=nonroot . .
COPY . .
# Show build caching stats to check if it was used in the end.
# Has to be the part of the same RUN since cachepot daemon is killed in the end of this RUN, losing the compilation stats.

View File

@@ -255,51 +255,6 @@ RUN wget https://github.com/theory/pgtap/archive/refs/tags/v1.2.0.tar.gz -O pgta
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgtap.control
#########################################################################################
#
# Layer "prefix-pg-build"
# compile Prefix extension
#
#########################################################################################
FROM build-deps AS prefix-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.9.tar.gz -O prefix.tar.gz && \
mkdir prefix-src && cd prefix-src && tar xvzf ../prefix.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/prefix.control
#########################################################################################
#
# Layer "hll-pg-build"
# compile hll extension
#
#########################################################################################
FROM build-deps AS hll-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.17.tar.gz -O hll.tar.gz && \
mkdir hll-src && cd hll-src && tar xvzf ../hll.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/hll.control
#########################################################################################
#
# Layer "plpgsql-check-pg-build"
# compile plpgsql_check extension
#
#########################################################################################
FROM build-deps AS plpgsql-check-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.3.2.tar.gz -O plpgsql_check.tar.gz && \
mkdir plpgsql_check-src && cd plpgsql_check-src && tar xvzf ../plpgsql_check.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/plpgsql_check.control
#########################################################################################
#
# Layer "rust extensions"
@@ -323,7 +278,7 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
chmod +x rustup-init && \
./rustup-init -y --no-modify-path --profile minimal --default-toolchain stable && \
rm rustup-init && \
cargo install --locked --version 0.7.3 cargo-pgx && \
cargo install --git https://github.com/vadim2404/pgx --branch neon_abi_v0.6.1 --locked cargo-pgx && \
/bin/bash -c 'cargo pgx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
USER root
@@ -337,11 +292,11 @@ USER root
FROM rust-extensions-build AS pg-jsonschema-pg-build
# there is no release tag yet, but we need it due to the superuser fix in the control file
RUN wget https://github.com/supabase/pg_jsonschema/archive/caeab60d70b2fd3ae421ec66466a3abbb37b7ee6.tar.gz -O pg_jsonschema.tar.gz && \
mkdir pg_jsonschema-src && cd pg_jsonschema-src && tar xvzf ../pg_jsonschema.tar.gz --strip-components=1 -C . && \
sed -i 's/pgx = "0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
RUN git clone --depth=1 --single-branch --branch neon_abi_v0.1.4 https://github.com/vadim2404/pg_jsonschema/ && \
cd pg_jsonschema && \
cargo pgx install --release && \
# it's needed to enable extension because it uses untrusted C language
sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_jsonschema.control && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_jsonschema.control
#########################################################################################
@@ -353,32 +308,13 @@ RUN wget https://github.com/supabase/pg_jsonschema/archive/caeab60d70b2fd3ae421e
FROM rust-extensions-build AS pg-graphql-pg-build
# Currently pgx version bump to >= 0.7.2 causes "call to unsafe function" compliation errors in
# pgx-contrib-spiext. There is a branch that removes that dependency, so use it. It is on the
# same 1.1 version we've used before.
RUN git clone -b remove-pgx-contrib-spiext --single-branch https://github.com/yrashk/pg_graphql && \
cd pg_graphql && \
sed -i 's/pgx = "~0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
sed -i 's/pgx-tests = "~0.7.1"/pgx-tests = "0.7.3"/g' Cargo.toml && \
RUN git clone --depth=1 --single-branch --branch neon_abi_v1.1.0 https://github.com/vadim2404/pg_graphql && \
cd pg_graphql && \
cargo pgx install --release && \
# it's needed to enable extension because it uses untrusted C language
sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_graphql.control && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_graphql.control
#########################################################################################
#
# Layer "pg-tiktoken-build"
# Compile "pg_tiktoken" extension
#
#########################################################################################
FROM rust-extensions-build AS pg-tiktoken-pg-build
RUN git clone --depth=1 --single-branch https://github.com/kelvich/pg_tiktoken && \
cd pg_tiktoken && \
cargo pgx install --release && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_tiktoken.control
#########################################################################################
#
# Layer "neon-pg-ext-build"
@@ -396,23 +332,15 @@ COPY --from=vector-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pgjwt-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-jsonschema-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-graphql-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-tiktoken-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=hypopg-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-hashids-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=rum-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pgtap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=prefix-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=hll-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=plpgsql-check-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) \
PG_CONFIG=/usr/local/pgsql/bin/pg_config \
-C pgxn/neon \
-s install && \
make -j $(getconf _NPROCESSORS_ONLN) \
PG_CONFIG=/usr/local/pgsql/bin/pg_config \
-C pgxn/neon_utils \
-s install
#########################################################################################
@@ -467,7 +395,7 @@ COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-deb
# Install:
# libreadline8 for psql
# libicu67, locales for collations (including ICU and plpgsql_check)
# libicu67, locales for collations (including ICU)
# libossp-uuid16 for extension ossp-uuid
# libgeos, libgdal, libsfcgal1, libproj and libprotobuf-c1 for PostGIS
# libxml2, libxslt1.1 for xml2

View File

@@ -1,70 +1,25 @@
# Note: this file *mostly* just builds on Dockerfile.compute-node
ARG SRC_IMAGE
ARG VM_INFORMANT_VERSION=v0.1.14
# on libcgroup update, make sure to check bootstrap.sh for changes
ARG LIBCGROUP_VERSION=v2.0.3
ARG VM_INFORMANT_VERSION=v0.1.6
# Pull VM informant, to copy from later
# Pull VM informant and set up inittab
FROM neondatabase/vm-informant:$VM_INFORMANT_VERSION as informant
# Build cgroup-tools
#
# At time of writing (2023-03-14), debian bullseye has a version of cgroup-tools (technically
# libcgroup) that doesn't support cgroup v2 (version 0.41-11). Unfortunately, the vm-informant
# requires cgroup v2, so we'll build cgroup-tools ourselves.
FROM debian:bullseye-slim as libcgroup-builder
ARG LIBCGROUP_VERSION
RUN set -exu \
&& apt update \
&& apt install --no-install-recommends -y \
git \
ca-certificates \
automake \
cmake \
make \
gcc \
byacc \
flex \
libtool \
libpam0g-dev \
&& git clone --depth 1 -b $LIBCGROUP_VERSION https://github.com/libcgroup/libcgroup \
&& INSTALL_DIR="/libcgroup-install" \
&& mkdir -p "$INSTALL_DIR/bin" "$INSTALL_DIR/include" \
&& cd libcgroup \
# extracted from bootstrap.sh, with modified flags:
&& (test -d m4 || mkdir m4) \
&& autoreconf -fi \
&& rm -rf autom4te.cache \
&& CFLAGS="-O3" ./configure --prefix="$INSTALL_DIR" --sysconfdir=/etc --localstatedir=/var --enable-opaque-hierarchy="name=systemd" \
# actually build the thing...
&& make install
# Combine, starting from non-VM compute node image.
FROM $SRC_IMAGE as base
# Temporarily set user back to root so we can run adduser, set inittab
USER root
RUN adduser vm-informant --disabled-password --no-create-home
RUN set -e \
&& rm -f /etc/inittab \
&& touch /etc/inittab
RUN set -e \
&& echo "::sysinit:cgconfigparser -l /etc/cgconfig.conf -s 1664" >> /etc/inittab \
&& CONNSTR="dbname=neondb user=cloud_admin sslmode=disable" \
&& ARGS="--auto-restart --cgroup=neon-postgres --pgconnstr=\"$CONNSTR\"" \
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant $ARGS'" >> /etc/inittab
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant --auto-restart'" >> /etc/inittab
# Combine, starting from non-VM compute node image.
FROM $SRC_IMAGE as base
# Temporarily set user back to root so we can run adduser
USER root
RUN adduser vm-informant --disabled-password --no-create-home
USER postgres
ADD vm-cgconfig.conf /etc/cgconfig.conf
COPY --from=informant /etc/inittab /etc/inittab
COPY --from=informant /usr/bin/vm-informant /usr/local/bin/vm-informant
COPY --from=libcgroup-builder /libcgroup-install/bin/* /usr/bin/
COPY --from=libcgroup-builder /libcgroup-install/lib/* /usr/lib/
COPY --from=libcgroup-builder /libcgroup-install/sbin/* /usr/sbin/
ENTRYPOINT ["/usr/sbin/cgexec", "-g", "*:neon-postgres", "/usr/local/bin/compute_ctl"]

View File

@@ -39,6 +39,8 @@ endif
# been no changes to the files. Changing the mtime triggers an
# unnecessary rebuild of 'postgres_ffi'.
PG_CONFIGURE_OPTS += INSTALL='$(ROOT_PROJECT_DIR)/scripts/ninstall.sh -C'
PG_CONFIGURE_OPTS += CC=clang
PG_CONFIGURE_OPTS += CCX=clang++
# Choose whether we should be silent or verbose
CARGO_BUILD_FLAGS += --$(if $(filter s,$(MAKEFLAGS)),quiet,verbose)
@@ -133,11 +135,12 @@ neon-pg-ext-%: postgres-%
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile install
+@echo "Compiling neon_utils $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-utils-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile install
.PHONY:
neon-pg-ext-walproposer:
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v15/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-v15 \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile install
.PHONY: neon-pg-ext-clean-%
neon-pg-ext-clean-%:
@@ -150,9 +153,6 @@ neon-pg-ext-clean-%:
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile clean
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile clean
.PHONY: neon-pg-ext
neon-pg-ext: \

View File

@@ -11,7 +11,6 @@ clap.workspace = true
futures.workspace = true
hyper = { workspace = true, features = ["full"] }
notify.workspace = true
num_cpus.workspace = true
opentelemetry.workspace = true
postgres.workspace = true
regex.workspace = true

View File

@@ -25,7 +25,6 @@ use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use postgres::{Client, NoTls};
use serde::{Serialize, Serializer};
use tokio_postgres;
use tracing::{info, instrument, warn};
use crate::checker::create_writability_check_data;
@@ -285,7 +284,6 @@ impl ComputeNode {
handle_role_deletions(self, &mut client)?;
handle_grants(self, &mut client)?;
create_writability_check_data(&mut client)?;
handle_extensions(&self.spec, &mut client)?;
// 'Close' connection
drop(client);
@@ -402,43 +400,4 @@ impl ComputeNode {
Ok(())
}
/// Select `pg_stat_statements` data and return it as a stringified JSON
pub async fn collect_insights(&self) -> String {
let mut result_rows: Vec<String> = Vec::new();
let connect_result = tokio_postgres::connect(self.connstr.as_str(), NoTls).await;
let (client, connection) = connect_result.unwrap();
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let result = client
.simple_query(
"SELECT
row_to_json(pg_stat_statements)
FROM
pg_stat_statements
WHERE
userid != 'cloud_admin'::regrole::oid
ORDER BY
(mean_exec_time + mean_plan_time) DESC
LIMIT 100",
)
.await;
if let Ok(raw_rows) = result {
for message in raw_rows.iter() {
if let postgres::SimpleQueryMessage::Row(row) = message {
if let Some(json) = row.get(0) {
result_rows.push(json.to_string());
}
}
}
format!("{{\"pg_stat_statements\": [{}]}}", result_rows.join(","))
} else {
"{{\"pg_stat_statements\": []}}".to_string()
}
}
}

View File

@@ -7,7 +7,6 @@ use crate::compute::ComputeNode;
use anyhow::Result;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use num_cpus;
use serde_json;
use tracing::{error, info};
use tracing_utils::http::OtelName;
@@ -34,13 +33,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
Response::new(Body::from(serde_json::to_string(&compute.metrics).unwrap()))
}
// Collect Postgres current usage insights
(&Method::GET, "/insights") => {
info!("serving /insights GET request");
let insights = compute.collect_insights().await;
Response::new(Body::from(insights))
}
(&Method::POST, "/check_writability") => {
info!("serving /check_writability POST request");
let res = crate::checker::check_writability(compute).await;
@@ -50,17 +42,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
(&Method::GET, "/info") => {
let num_cpus = num_cpus::get_physical();
info!("serving /info GET request. num_cpus: {}", num_cpus);
Response::new(Body::from(
serde_json::json!({
"num_cpus": num_cpus,
})
.to_string(),
))
}
// Return the `404 Not Found` for any other routes.
_ => {
let mut not_found = Response::new(Body::from("404 Not Found"));

View File

@@ -10,12 +10,12 @@ paths:
/status:
get:
tags:
- Info
- "info"
summary: Get compute node internal status
description: ""
operationId: getComputeStatus
responses:
200:
"200":
description: ComputeState
content:
application/json:
@@ -25,58 +25,27 @@ paths:
/metrics.json:
get:
tags:
- Info
- "info"
summary: Get compute node startup metrics in JSON format
description: ""
operationId: getComputeMetricsJSON
responses:
200:
"200":
description: ComputeMetrics
content:
application/json:
schema:
$ref: "#/components/schemas/ComputeMetrics"
/insights:
get:
tags:
- Info
summary: Get current compute insights in JSON format
description: |
Note, that this doesn't include any historical data
operationId: getComputeInsights
responses:
200:
description: Compute insights
content:
application/json:
schema:
$ref: "#/components/schemas/ComputeInsights"
/info:
get:
tags:
- "info"
summary: Get info about the compute Pod/VM
description: ""
operationId: getInfo
responses:
"200":
description: Info
content:
application/json:
schema:
$ref: "#/components/schemas/Info"
/check_writability:
post:
tags:
- Check
- "check"
summary: Check that we can write new data on this compute
description: ""
operationId: checkComputeWritability
responses:
200:
"200":
description: Check result
content:
text/plain:
@@ -111,15 +80,6 @@ components:
total_startup_ms:
type: integer
Info:
type: object
description: Information about VM/Pod
required:
- num_cpus
properties:
num_cpus:
type: integer
ComputeState:
type: object
required:
@@ -136,15 +96,6 @@ components:
type: string
description: Text of the error during compute startup, if any
ComputeInsights:
type: object
properties:
pg_stat_statements:
description: Contains raw output from pg_stat_statements in JSON format
type: array
items:
type: object
ComputeStatus:
type: string
enum:

View File

@@ -47,23 +47,12 @@ pub struct GenericOption {
/// declare a `trait` on it.
pub type GenericOptions = Option<Vec<GenericOption>>;
/// Escape a string for including it in a SQL literal
fn escape_literal(s: &str) -> String {
s.replace('\'', "''").replace('\\', "\\\\")
}
/// Escape a string so that it can be used in postgresql.conf.
/// Same as escape_literal, currently.
fn escape_conf_value(s: &str) -> String {
s.replace('\'', "''").replace('\\', "\\\\")
}
impl GenericOption {
/// Represent `GenericOption` as SQL statement parameter.
pub fn to_pg_option(&self) -> String {
if let Some(val) = &self.value {
match self.vartype.as_ref() {
"string" => format!("{} '{}'", self.name, escape_literal(val)),
"string" => format!("{} '{}'", self.name, val),
_ => format!("{} {}", self.name, val),
}
} else {
@@ -74,8 +63,6 @@ impl GenericOption {
/// Represent `GenericOption` as configuration option.
pub fn to_pg_setting(&self) -> String {
if let Some(val) = &self.value {
// TODO: check in the console DB that we don't have these settings
// set for any non-deleted project and drop this override.
let name = match self.name.as_str() {
"safekeepers" => "neon.safekeepers",
"wal_acceptor_reconnect" => "neon.safekeeper_reconnect_timeout",
@@ -84,7 +71,7 @@ impl GenericOption {
};
match self.vartype.as_ref() {
"string" => format!("{} = '{}'", name, escape_conf_value(val)),
"string" => format!("{} = '{}'", name, val),
_ => format!("{} = {}", name, val),
}
} else {
@@ -120,7 +107,6 @@ impl PgOptionsSerialize for GenericOptions {
.map(|op| op.to_pg_setting())
.collect::<Vec<String>>()
.join("\n")
+ "\n" // newline after last setting
} else {
"".to_string()
}

View File

@@ -515,18 +515,3 @@ pub fn handle_grants(node: &ComputeNode, client: &mut Client) -> Result<()> {
Ok(())
}
/// Create required system extensions
#[instrument(skip_all)]
pub fn handle_extensions(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
if libs.contains("pg_stat_statements") {
// Create extension only if this compute really needs it
let query = "CREATE EXTENSION IF NOT EXISTS pg_stat_statements";
info!("creating system extensions with query: {}", query);
client.simple_query(query)?;
}
}
Ok(())
}

View File

@@ -178,11 +178,6 @@
"name": "neon.pageserver_connstring",
"value": "host=127.0.0.1 port=6400",
"vartype": "string"
},
{
"name": "test.escaping",
"value": "here's a backslash \\ and a quote ' and a double-quote \" hooray",
"vartype": "string"
}
]
},

View File

@@ -28,30 +28,7 @@ mod pg_helpers_tests {
assert_eq!(
spec.cluster.settings.as_pg_settings(),
r#"fsync = off
wal_level = replica
hot_standby = on
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
wal_log_hints = on
log_connections = on
shared_buffers = 32768
port = 55432
max_connections = 100
max_wal_senders = 10
listen_addresses = '0.0.0.0'
wal_sender_timeout = 0
password_encryption = md5
maintenance_work_mem = 65536
max_parallel_workers = 8
max_worker_processes = 8
neon.tenant_id = 'b0554b632bd4d547a63b86c3630317e8'
max_replication_slots = 10
neon.timeline_id = '2414a61ffc94e428f14b5758fe308e13'
shared_preload_libraries = 'neon'
synchronous_standby_names = 'walproposer'
neon.pageserver_connstring = 'host=127.0.0.1 port=6400'
test.escaping = 'here''s a backslash \\ and a quote '' and a double-quote " hooray'
"#
"fsync = off\nwal_level = replica\nhot_standby = on\nneon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'\nwal_log_hints = on\nlog_connections = on\nshared_buffers = 32768\nport = 55432\nmax_connections = 100\nmax_wal_senders = 10\nlisten_addresses = '0.0.0.0'\nwal_sender_timeout = 0\npassword_encryption = md5\nmaintenance_work_mem = 65536\nmax_parallel_workers = 8\nmax_worker_processes = 8\nneon.tenant_id = 'b0554b632bd4d547a63b86c3630317e8'\nmax_replication_slots = 10\nneon.timeline_id = '2414a61ffc94e428f14b5758fe308e13'\nshared_preload_libraries = 'neon'\nsynchronous_standby_names = 'walproposer'\nneon.pageserver_connstring = 'host=127.0.0.1 port=6400'"
);
}

View File

@@ -24,7 +24,6 @@ url.workspace = true
# Note: Do not directly depend on pageserver or safekeeper; use pageserver_api or safekeeper_api
# instead, so that recompile times are better.
pageserver_api.workspace = true
postgres_backend.workspace = true
safekeeper_api.workspace = true
postgres_connection.workspace = true
storage_broker.workspace = true

View File

@@ -2,8 +2,7 @@
[pageserver]
listen_pg_addr = '127.0.0.1:64000'
listen_http_addr = '127.0.0.1:9898'
pg_auth_type = 'Trust'
http_auth_type = 'Trust'
auth_type = 'Trust'
[[safekeepers]]
id = 1

View File

@@ -3,8 +3,7 @@
[pageserver]
listen_pg_addr = '127.0.0.1:64000'
listen_http_addr = '127.0.0.1:9898'
pg_auth_type = 'Trust'
http_auth_type = 'Trust'
auth_type = 'Trust'
[[safekeepers]]
id = 1

View File

@@ -17,7 +17,6 @@ use pageserver_api::{
DEFAULT_HTTP_LISTEN_ADDR as DEFAULT_PAGESERVER_HTTP_ADDR,
DEFAULT_PG_LISTEN_ADDR as DEFAULT_PAGESERVER_PG_ADDR,
};
use postgres_backend::AuthType;
use safekeeper_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
@@ -31,6 +30,7 @@ use utils::{
auth::{Claims, Scope},
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
project_git_version,
};
@@ -53,15 +53,14 @@ listen_addr = '{DEFAULT_BROKER_ADDR}'
id = {DEFAULT_PAGESERVER_ID}
listen_pg_addr = '{DEFAULT_PAGESERVER_PG_ADDR}'
listen_http_addr = '{DEFAULT_PAGESERVER_HTTP_ADDR}'
pg_auth_type = '{trust_auth}'
http_auth_type = '{trust_auth}'
auth_type = '{pageserver_auth_type}'
[[safekeepers]]
id = {DEFAULT_SAFEKEEPER_ID}
pg_port = {DEFAULT_SAFEKEEPER_PG_PORT}
http_port = {DEFAULT_SAFEKEEPER_HTTP_PORT}
"#,
trust_auth = AuthType::Trust,
pageserver_auth_type = AuthType::Trust,
)
}
@@ -628,7 +627,7 @@ fn handle_pg(pg_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
let node = cplane.nodes.get(&(tenant_id, node_name.to_string()));
let auth_token = if matches!(env.pageserver.pg_auth_type, AuthType::NeonJWT) {
let auth_token = if matches!(env.pageserver.auth_type, AuthType::NeonJWT) {
let claims = Claims::new(Some(tenant_id), Scope::Tenant);
Some(env.generate_auth_token(&claims)?)

View File

@@ -11,10 +11,10 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use postgres_backend::AuthType;
use utils::{
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
};
use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION};
@@ -97,7 +97,7 @@ impl ComputeControlPlane {
});
node.create_pgdata()?;
node.setup_pg_conf(self.env.pageserver.pg_auth_type)?;
node.setup_pg_conf(self.env.pageserver.auth_type)?;
self.nodes
.insert((tenant_id, node.name.clone()), Arc::clone(&node));

View File

@@ -5,7 +5,6 @@
use anyhow::{bail, ensure, Context};
use postgres_backend::AuthType;
use reqwest::Url;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
@@ -20,6 +19,7 @@ use std::process::{Command, Stdio};
use utils::{
auth::{encode_from_key_file, Claims, Scope},
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
postgres_backend::AuthType,
};
use crate::safekeeper::SafekeeperNode;
@@ -110,14 +110,12 @@ impl NeonBroker {
pub struct PageServerConf {
// node id
pub id: NodeId,
// Pageserver connection settings
pub listen_pg_addr: String,
pub listen_http_addr: String,
// auth type used for the PG and HTTP ports
pub pg_auth_type: AuthType,
pub http_auth_type: AuthType,
// used to determine which auth type is used
pub auth_type: AuthType,
// jwt auth token used for communication with pageserver
pub auth_token: String,
@@ -129,8 +127,7 @@ impl Default for PageServerConf {
id: NodeId(0),
listen_pg_addr: String::new(),
listen_http_addr: String::new(),
pg_auth_type: AuthType::Trust,
http_auth_type: AuthType::Trust,
auth_type: AuthType::Trust,
auth_token: String::new(),
}
}

View File

@@ -11,7 +11,6 @@ use anyhow::{bail, Context};
use pageserver_api::models::{
TenantConfigRequest, TenantCreateRequest, TenantInfo, TimelineCreateRequest, TimelineInfo,
};
use postgres_backend::AuthType;
use postgres_connection::{parse_host_port, PgConnectionConfig};
use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::{IntoUrl, Method};
@@ -21,6 +20,7 @@ use utils::{
http::error::HttpErrorBody,
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
};
use crate::{background_process, local_env::LocalEnv};
@@ -82,7 +82,7 @@ impl PageServerNode {
let (host, port) = parse_host_port(&env.pageserver.listen_pg_addr)
.expect("Unable to parse listen_pg_addr");
let port = port.unwrap_or(5432);
let password = if env.pageserver.pg_auth_type == AuthType::NeonJWT {
let password = if env.pageserver.auth_type == AuthType::NeonJWT {
Some(env.pageserver.auth_token.clone())
} else {
None
@@ -106,32 +106,25 @@ impl PageServerNode {
self.env.pg_distrib_dir_raw().display()
);
let http_auth_type_param =
format!("http_auth_type='{}'", self.env.pageserver.http_auth_type);
let authg_type_param = format!("auth_type='{}'", self.env.pageserver.auth_type);
let listen_http_addr_param = format!(
"listen_http_addr='{}'",
self.env.pageserver.listen_http_addr
);
let pg_auth_type_param = format!("pg_auth_type='{}'", self.env.pageserver.pg_auth_type);
let listen_pg_addr_param =
format!("listen_pg_addr='{}'", self.env.pageserver.listen_pg_addr);
let broker_endpoint_param = format!("broker_endpoint='{}'", self.env.broker.client_url());
let mut overrides = vec![
id,
pg_distrib_dir_param,
http_auth_type_param,
pg_auth_type_param,
authg_type_param,
listen_http_addr_param,
listen_pg_addr_param,
broker_endpoint_param,
];
if self.env.pageserver.http_auth_type != AuthType::Trust
|| self.env.pageserver.pg_auth_type != AuthType::Trust
{
if self.env.pageserver.auth_type != AuthType::Trust {
overrides.push("auth_validation_public_key_path='auth_public_key.pem'".to_owned());
}
overrides
@@ -254,10 +247,7 @@ impl PageServerNode {
}
fn pageserver_env_variables(&self) -> anyhow::Result<Vec<(String, String)>> {
// FIXME: why is this tied to pageserver's auth type? Whether or not the safekeeper
// needs a token, and how to generate that token, seems independent to whether
// the pageserver requires a token in incoming requests.
Ok(if self.env.pageserver.http_auth_type != AuthType::Trust {
Ok(if self.env.pageserver.auth_type != AuthType::Trust {
// Generate a token to connect from the pageserver to a safekeeper
let token = self
.env
@@ -293,7 +283,7 @@ impl PageServerNode {
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
let mut builder = self.http_client.request(method, url);
if self.env.pageserver.http_auth_type == AuthType::NeonJWT {
if self.env.pageserver.auth_type == AuthType::NeonJWT {
builder = builder.bearer_auth(&self.env.pageserver.auth_token)
}
builder

View File

@@ -115,10 +115,6 @@ impl SafekeeperNode {
let datadir = self.datadir_path();
let id_string = id.to_string();
// TODO: add availability_zone to the config.
// Right now we just specify any value here and use it to check metrics in tests.
let availability_zone = format!("sk-{}", id_string);
let mut args = vec![
"-D",
datadir.to_str().with_context(|| {
@@ -130,8 +126,6 @@ impl SafekeeperNode {
&listen_pg,
"--listen-http",
&listen_http,
"--availability-zone",
&availability_zone,
];
if !self.conf.sync {
args.push("--no-sync");

View File

@@ -29,41 +29,6 @@ These components should not have access to the private key and may only get toke
The key pair is generated once for an installation of compute/pageserver/safekeeper, e.g. by `neon_local init`.
There is currently no way to rotate the key without bringing down all components.
### Token format
The JWT tokens in Neon use RSA as the algorithm. Example:
Header:
```
{
"alg": "RS512", # RS256, RS384, or RS512
"typ": "JWT"
}
```
Payload:
```
{
"scope": "tenant", # "tenant", "pageserverapi", or "safekeeperdata"
"tenant_id": "5204921ff44f09de8094a1390a6a50f6",
}
```
Meanings of scope:
"tenant": Provides access to all data for a specific tenant
"pageserverapi": Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
Should only be used e.g. for status check/tenant creation/list.
"safekeeperdata": Provides blanket access to all data on the safekeeper plus safekeeper-wide APIs.
Should only be used e.g. for status check.
Currently also used for connection from any pageserver to any safekeeper.
### CLI
CLI generates a key pair during call to `neon_local init` with the following commands:
@@ -137,12 +102,10 @@ Each compute should present a token valid for the timeline's tenant.
Pageserver also has HTTP API: some parts are per-tenant,
some parts are server-wide, these are different scopes.
Authentication can be enabled separately for the HTTP mgmt API, and
for the libpq connections from compute. The `http_auth_type` and
`pg_auth_type` configuration variables in Pageserver's config may
have one of these values:
The `auth_type` configuration variable in Pageserver's config may have
either of three values:
* `Trust` removes all authentication.
* `Trust` removes all authentication. The outdated `MD5` value does likewise
* `NeonJWT` enables JWT validation.
Tokens are validated using the public key which lies in a PEM file
specified in the `auth_validation_public_key_path` config.

View File

@@ -129,12 +129,13 @@ Run `poetry shell` to activate the virtual environment.
Alternatively, use `poetry run` to run a single command in the venv, e.g. `poetry run pytest`.
### Obligatory checks
We force code formatting via `black`, `ruff`, and type hints via `mypy`.
We force code formatting via `black`, `isort` and type hints via `mypy`.
Run the following commands in the repository's root (next to `pyproject.toml`):
```bash
poetry run isort . # Imports are reformatted
poetry run black . # All code is reformatted
poetry run ruff . # Python linter
poetry run flake8 . # Python linter
poetry run mypy . # Ensure there are no typing errors
```

View File

@@ -1,26 +0,0 @@
[package]
name = "postgres_backend"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
async-trait.workspace = true
anyhow.workspace = true
bytes.workspace = true
futures.workspace = true
rustls.workspace = true
serde.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-rustls.workspace = true
tracing.workspace = true
pq_proto.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
once_cell.workspace = true
rustls-pemfile.workspace = true
tokio-postgres.workspace = true
tokio-postgres-rustls.workspace = true

View File

@@ -1,931 +0,0 @@
//! Server-side asynchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use anyhow::Context;
use bytes::Bytes;
use futures::pin_mut;
use serde::{Deserialize, Serialize};
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Poll};
use std::{fmt, io};
use std::{future::Future, str::FromStr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, trace};
use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
use pq_proto::{
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
SQLSTATE_SUCCESSFUL_COMPLETION,
};
/// An error, occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
/// The connection was lost while processing the query.
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl From<io::Error> for QueryError {
fn from(e: io::Error) -> Self {
Self::Disconnected(ConnectionError::Io(e))
}
}
impl QueryError {
pub fn pg_error_code(&self) -> &'static [u8; 5] {
match self {
Self::Disconnected(_) => b"08006", // connection failure
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
}
}
}
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
#[async_trait::async_trait]
pub trait Handler<IO> {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care). It will also flush out the output buffer.
async fn process_query(
&mut self,
pgb: &mut PostgresBackend<IO>,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend<IO>,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend<IO>,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
/// Nothing happened yet.
Initialization,
/// Encryption handshake is done; waiting for encrypted Startup message.
Encrypted,
/// Waiting for password (auth token).
Authentication,
/// Performed handshake and auth, ReadyForQuery is issued.
Established,
Closed,
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
pub enum MaybeTlsStream<IO> {
Unencrypted(IO),
Tls(Box<tokio_rustls::server::TlsStream<IO>>),
}
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
}
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
})
}
}
/// Either full duplex Framed or write only half; the latter is left in
/// PostgresBackend after call to `split`. In principle we could always store a
/// pair of splitted handles, but that would force to to pay splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly<IO> {
Full(Framed<MaybeTlsStream<IO>>),
WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
Broken, // temporary value palmed off during the split
}
impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
MaybeWriteOnly::WriteOnly(_) => {
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
}
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_message().await,
MaybeWriteOnly::WriteOnly(_) => {
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
}
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
match self {
MaybeWriteOnly::Full(framed) => framed.write_message(msg),
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn flush(&mut self) -> io::Result<()> {
match self {
MaybeWriteOnly::Full(framed) => framed.flush().await,
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn shutdown(&mut self) -> io::Result<()> {
match self {
MaybeWriteOnly::Full(framed) => framed.shutdown().await,
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
}
pub struct PostgresBackend<IO> {
framed: MaybeWriteOnly<IO>,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
/// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend<tokio::net::TcpStream> {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
let stream = MaybeTlsStream::Unencrypted(socket);
Ok(Self {
framed: MaybeWriteOnly::Full(Framed::new(stream)),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
pub fn new_from_io(
socket: IO,
peer_addr: SocketAddr,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let stream = MaybeTlsStream::Unencrypted(socket);
Ok(Self {
framed: MaybeWriteOnly::Full(Framed::new(stream)),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
if let ProtoState::Closed = self.state {
Ok(None)
} else {
let m = self.framed.read_message().await?;
trace!("read msg {:?}", m);
Ok(m)
}
}
/// Write message into internal output buffer, doesn't flush it. Technically
/// error type can be only ProtocolError here (if, unlikely, serialization
/// fails), but callers typically wrap it anyway.
pub fn write_message_noflush(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.framed.write_message_noflush(message)?;
trace!("wrote msg {:?}", message);
Ok(self)
}
/// Flush output buffer into the socket.
pub async fn flush(&mut self) -> io::Result<()> {
self.framed.flush().await
}
/// Polling version of `flush()`, saves the caller need to pin.
pub fn poll_flush(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let flush_fut = self.flush();
pin_mut!(flush_fut);
flush_fut.poll(cx)
}
/// Write message into internal output buffer and flush it to the stream.
pub async fn write_message(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.write_message_noflush(message)?;
self.flush().await?;
Ok(self)
}
/// Returns an AsyncWrite implementation that wraps all the data written
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
CopyDataWriter { pgb: self }
}
/// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
let ret = self.run_message_loop(handler, shutdown_watcher).await;
// socket might be already closed, e.g. if previously received error,
// so ignore result.
self.framed.shutdown().await.ok();
ret
}
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
trace!("postgres backend to {:?} started", self.peer_addr);
tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received during handshake");
return Ok(())
},
result = self.handshake(handler) => {
// Handshake complete.
result?;
if self.state == ProtoState::Closed {
return Ok(()); // EOF during handshake
}
}
);
// Authentication completed
let mut query_string = Bytes::new();
while let Some(msg) = tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received in run_message_loop");
Ok(None)
},
msg = self.read_message() => { msg },
)? {
trace!("got message {:?}", msg);
let result = self.process_message(handler, msg, &mut query_string).await;
self.flush().await?;
match result? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => break,
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
async fn tls_upgrade(
src: MaybeTlsStream<IO>,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<MaybeTlsStream<IO>> {
match src {
MaybeTlsStream::Unencrypted(s) => {
let acceptor = TlsAcceptor::from(tls_config);
let tls_stream = acceptor.accept(s).await?;
Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
}
MaybeTlsStream::Tls(_) => {
anyhow::bail!("TLS already started");
}
}
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
// temporary replace stream with fake to cook TLS one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
let tls_config = self
.tls_config
.as_ref()
.context("start_tls called without conf")?
.clone();
let tls_framed = framed
.map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
.await?;
// push back ready TLS stream
self.framed = MaybeWriteOnly::Full(tls_framed);
Ok(())
}
MaybeWriteOnly::WriteOnly(_) => {
anyhow::bail!("TLS upgrade attempt in split state")
}
MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
}
}
/// Split off owned read part from which messages can be read in different
/// task/thread.
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
// temporary replace stream with fake to cook split one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
let (reader, writer) = framed.split();
self.framed = MaybeWriteOnly::WriteOnly(writer);
Ok(PostgresBackendReader(reader))
}
MaybeWriteOnly::WriteOnly(_) => {
anyhow::bail!("PostgresBackend is already split")
}
MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
}
}
/// Join read part back.
pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
// temporary replace stream with fake to cook joined one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(_) => {
anyhow::bail!("PostgresBackend is not split")
}
MaybeWriteOnly::WriteOnly(writer) => {
let joined = Framed::unsplit(reader.0, writer);
self.framed = MaybeWriteOnly::Full(joined);
Ok(())
}
MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
}
}
/// Perform handshake with the client, transitioning to Established.
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
while self.state < ProtoState::Authentication {
match self.framed.read_startup_message().await? {
Some(msg) => {
self.process_startup_message(handler, msg).await?;
}
None => {
trace!(
"postgres backend to {:?} received EOF during handshake",
self.peer_addr
);
self.state = ProtoState::Closed;
return Ok(());
}
}
}
// Perform auth, if needed.
if self.state == ProtoState::Authentication {
match self.framed.read_message().await? {
Some(FeMessage::PasswordMessage(m)) => {
assert!(self.auth_type == AuthType::NeonJWT);
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)
.await?;
self.state = ProtoState::Established;
}
Some(m) => {
return Err(QueryError::Other(anyhow::anyhow!(
"Unexpected message {:?} while waiting for handshake",
m
)));
}
None => {
trace!(
"postgres backend to {:?} received EOF during auth",
self.peer_addr
);
self.state = ProtoState::Closed;
return Ok(());
}
}
}
Ok(())
}
/// Process startup packet:
/// - transition to Established if auth type is trust
/// - transition to Authentication if auth type is NeonJWT.
/// - or perform TLS handshake -- then need to call this again to receive
/// actual startup packet.
async fn process_startup_message(
&mut self,
handler: &mut impl Handler<IO>,
msg: FeStartupPacket,
) -> Result<(), QueryError> {
assert!(self.state < ProtoState::Authentication);
let have_tls = self.tls_config.is_some();
match msg {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))
.await?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))
.await?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
.await?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &msg)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)
.await?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)
.await?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
return Err(QueryError::Other(anyhow::anyhow!(
"Unexpected CancelRequest message during handshake"
)));
}
}
Ok(())
}
async fn process_message(
&mut self,
handler: &mut impl Handler<IO>,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
assert!(self.state == ProtoState::Established);
match msg {
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message_noflush(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message_noflush(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message_noflush(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message_noflush(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_)
| FeMessage::CopyDone
| FeMessage::CopyFail
| FeMessage::PasswordMessage(_) => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {msg:?}",
)));
}
}
Ok(ProcessMsgResult::Continue)
}
/// Log as info/error result of handling COPY stream and send back
/// ErrorResponse if that makes sense. Shutdown the stream if we got
/// Terminate. TODO: transition into waiting for Sync msg if we initiate the
/// close.
pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
use CopyStreamHandlerEnd::*;
let expected_end = match &end {
ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
if is_expected_io_error(io_error) =>
{
true
}
_ => false,
};
if expected_end {
info!("terminated: {:#}", end);
} else {
error!("terminated: {:?}", end);
}
// Note: no current usages ever send this
if let CopyDone = &end {
if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
error!("failed to send CopyDone: {}", e);
}
}
if let Terminate = &end {
self.state = ProtoState::Closed;
}
let err_to_send_and_errcode = match &end {
ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
Other(_) => Some((end.to_string(), SQLSTATE_INTERNAL_ERROR)),
// Note: CopyFail in duplex copy is somewhat unexpected (at least to
// PG walsender; evidently and per my docs reading client should
// finish it with CopyDone). It is not a problem to recover from it
// finishing the stream in both directions like we do, but note that
// sync rust-postgres client (which we don't use anymore) hangs if
// socket is not closed here.
// https://github.com/sfackler/rust-postgres/issues/755
// https://github.com/neondatabase/neon/issues/935
//
// Currently, the version of tokio_postgres replication patch we use
// sends this when it closes the stream (e.g. pageserver decided to
// switch conn to another safekeeper and client gets dropped).
// Moreover, seems like 'connection' task errors with 'unexpected
// message from server' when it receives ErrorResponse (anything but
// CopyData/CopyDone) back.
CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
_ => None,
};
if let Some((err, errcode)) = err_to_send_and_errcode {
if let Err(ee) = self
.write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
.await
{
error!("failed to send ErrorResponse: {}", ee);
}
}
}
}
pub struct PostgresBackendReader<IO>(FramedReader<MaybeTlsStream<IO>>);
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
let m = self.0.read_message().await?;
trace!("read msg {:?}", m);
Ok(m)
}
/// Get CopyData contents of the next message in COPY stream or error
/// closing it. The error type is wider than actual errors which can happen
/// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
/// current callers.
pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
match self.read_message().await? {
Some(msg) => match msg {
FeMessage::CopyData(m) => Ok(m),
FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
_ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
))),
},
None => Err(CopyStreamHandlerEnd::EOF),
}
}
}
///
/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
/// messages.
///
pub struct CopyDataWriter<'a, IO> {
pgb: &'a mut PostgresBackend<IO>,
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
// It's not strictly required to flush between each message, but makes it easier
// to view in wireshark, and usually the messages that the callers write are
// decently-sized anyway.
if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
return Poll::Ready(Err(err));
}
// CopyData
// XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32.
this.pgb
.write_message_noflush(&BeMessage::CopyData(buf))
// write_message only writes to the buffer, so it can fail iff the
// message is invaid, but CopyData can't be invalid.
.map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
this.pgb.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
this.pgb.poll_flush(cx)
}
}
pub fn short_error(e: &QueryError) -> String {
match e {
QueryError::Disconnected(connection_error) => connection_error.to_string(),
QueryError::Other(e) => format!("{e:#}"),
}
}
fn log_query_error(query: &str, e: &QueryError) {
match e {
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
if is_expected_io_error(io_error) {
info!("query handler for '{query}' failed with expected io error: {io_error}");
} else {
error!("query handler for '{query}' failed with io error: {io_error}");
}
}
QueryError::Disconnected(other_connection_error) => {
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
}
QueryError::Other(e) => {
error!("query handler for '{query}' failed: {e:?}");
}
}
}
/// Something finishing handling of COPY stream, see handle_copy_stream_end.
/// This is not always a real error, but it allows to use ? and thiserror impls.
#[derive(thiserror::Error, Debug)]
pub enum CopyStreamHandlerEnd {
/// Handler initiates the end of streaming.
#[error("{0}")]
ServerInitiated(String),
#[error("received CopyDone")]
CopyDone,
#[error("received CopyFail")]
CopyFail,
#[error("received Terminate")]
Terminate,
#[error("EOF on COPY stream")]
EOF,
/// The connection was lost
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}

View File

@@ -1,140 +0,0 @@
/// Test postgres_backend_async with tokio_postgres
use once_cell::sync::Lazy;
use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
use pq_proto::{BeMessage, RowDescriptor};
use std::io::Cursor;
use std::{future, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
use tokio_postgres_rustls::MakeRustlsConnect;
// generate client, server test streams
async fn make_tcp_pair() -> (TcpStream, TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(addr).await.unwrap();
let (server_stream, _) = listener.accept().await.unwrap();
(client_stream, server_stream)
}
struct TestHandler {}
#[async_trait::async_trait]
impl<IO: AsyncRead + AsyncWrite + Unpin + Send> Handler<IO> for TestHandler {
// return single col 'hey' for any query
async fn process_query(
&mut self,
pgb: &mut PostgresBackend<IO>,
_query_string: &str,
) -> Result<(), QueryError> {
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
b"hey",
)]))?
.write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
Ok(())
}
}
// test that basic select works
#[tokio::test]
async fn simple_select() {
let (client_sock, server_sock) = make_tcp_pair().await;
// create and run pgbackend
let pgbackend =
PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
});
let conf = Config::new();
let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
if let SimpleQueryMessage::Row(row) = first_val {
let first_col = row.get(0).expect("first column");
assert_eq!(first_col, "hey");
} else {
panic!("expected SimpleQueryMessage::Row");
}
}
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("key.pem"));
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
});
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
});
// test that basic select with ssl works
#[tokio::test]
async fn simple_select_ssl() {
let (client_sock, server_sock) = make_tcp_pair().await;
let server_cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(server_cfg));
let pgbackend =
PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
});
let client_cfg = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&CERT).unwrap();
store
})
.with_no_client_auth();
let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
&mut make_tls_connect,
"localhost",
)
.expect("make_tls_connect");
let mut conf = Config::new();
conf.ssl_mode(SslMode::Require);
let (client, connection) = conf
.connect_raw(client_sock, tls_connect)
.await
.expect("connect");
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
if let SimpleQueryMessage::Row(row) = first_val {
let first_col = row.get(0).expect("first column");
assert_eq!(first_col, "hey");
} else {
panic!("expected SimpleQueryMessage::Row");
}
}

View File

@@ -63,7 +63,10 @@ fn main() -> anyhow::Result<()> {
pg_install_dir_versioned = cwd.join("..").join("..").join(pg_install_dir_versioned);
}
let pg_config_bin = pg_install_dir_versioned.join("bin").join("pg_config");
let pg_config_bin = pg_install_dir_versioned
.join(pg_version)
.join("bin")
.join("pg_config");
let inc_server_path: String = if pg_config_bin.exists() {
let output = Command::new(pg_config_bin)
.arg("--includedir-server")

View File

@@ -5,8 +5,8 @@ edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
bytes.workspace = true
byteorder.workspace = true
pin-project-lite.workspace = true
postgres-protocol.workspace = true
rand.workspace = true

View File

@@ -1,244 +0,0 @@
//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
//! the async stream based on (and buffered with) BytesMut. All functions are
//! cancellation safe.
//!
//! It is similar to what tokio_util::codec::Framed with appropriate codec
//! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
//! separately without using split from futures::stream::StreamExt (which
//! allocates box[1] in polling internally). tokio::io::split is used for splitting
//! instead. Plus we customize error messages more than a single type for all io
//! calls.
//!
//! [1] https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
use bytes::{Buf, BytesMut};
use std::{
future::Future,
io::{self, ErrorKind},
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
const INITIAL_CAPACITY: usize = 8 * 1024;
/// Error on postgres connection: either IO (physical transport error) or
/// protocol violation.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
#[error(transparent)]
Io(#[from] io::Error),
#[error(transparent)]
Protocol(#[from] ProtocolError),
}
impl ConnectionError {
/// Proxy stream.rs uses only io::Error; provide it.
pub fn into_io_error(self) -> io::Error {
match self {
ConnectionError::Io(io) => io,
ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
}
}
}
/// Wraps async io `stream`, providing messages to write/flush + read Postgres
/// messages.
pub struct Framed<S> {
stream: S,
read_buf: BytesMut,
write_buf: BytesMut,
}
impl<S> Framed<S> {
pub fn new(stream: S) -> Self {
Self {
stream,
read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
}
}
/// Get a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Deconstruct into the underlying stream and read buffer.
pub fn into_inner(self) -> (S, BytesMut) {
(self.stream, self.read_buf)
}
/// Return new Framed with stream type transformed by async f, for TLS
/// upgrade.
pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
where
F: FnOnce(S) -> Fut,
Fut: Future<Output = Result<S2, E>>,
{
let stream = f(self.stream).await?;
Ok(Framed {
stream,
read_buf: self.read_buf,
write_buf: self.write_buf,
})
}
}
impl<S: AsyncRead + Unpin> Framed<S> {
pub async fn read_startup_message(
&mut self,
) -> Result<Option<FeStartupPacket>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
}
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
}
}
impl<S: AsyncWrite + Unpin> Framed<S> {
/// Write next message to the output buffer; doesn't flush.
pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
BeMessage::write(&mut self.write_buf, msg)
}
/// Flush out the buffer. This function is cancellation safe: it can be
/// interrupted and flushing will be continued in the next call.
pub async fn flush(&mut self) -> Result<(), io::Error> {
flush(&mut self.stream, &mut self.write_buf).await
}
/// Flush out the buffer and shutdown the stream.
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
shutdown(&mut self.stream, &mut self.write_buf).await
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
/// Split into owned read and write parts. Beware of potential issues with
/// using halves in different tasks on TLS stream:
/// https://github.com/tokio-rs/tls/issues/40
pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
let (read_half, write_half) = tokio::io::split(self.stream);
let reader = FramedReader {
stream: read_half,
read_buf: self.read_buf,
};
let writer = FramedWriter {
stream: write_half,
write_buf: self.write_buf,
};
(reader, writer)
}
/// Join read and write parts back.
pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
Self {
stream: reader.stream.unsplit(writer.stream),
read_buf: reader.read_buf,
write_buf: writer.write_buf,
}
}
}
/// Read-only version of `Framed`.
pub struct FramedReader<S> {
stream: ReadHalf<S>,
read_buf: BytesMut,
}
impl<S: AsyncRead + Unpin> FramedReader<S> {
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
}
}
/// Write-only version of `Framed`.
pub struct FramedWriter<S> {
stream: WriteHalf<S>,
write_buf: BytesMut,
}
impl<S: AsyncWrite + Unpin> FramedWriter<S> {
/// Write next message to the output buffer; doesn't flush.
pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
BeMessage::write(&mut self.write_buf, msg)
}
/// Flush out the buffer. This function is cancellation safe: it can be
/// interrupted and flushing will be continued in the next call.
pub async fn flush(&mut self) -> Result<(), io::Error> {
flush(&mut self.stream, &mut self.write_buf).await
}
/// Flush out the buffer and shutdown the stream.
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
shutdown(&mut self.stream, &mut self.write_buf).await
}
}
/// Read next message from the stream. Returns Ok(None), if EOF happened and we
/// don't have remaining data in the buffer. This function is cancellation safe:
/// you can drop future which is not yet complete and finalize reading message
/// with the next call.
///
/// Parametrized to allow reading startup or usual message, having different
/// format.
async fn read_message<S: AsyncRead + Unpin, M, P>(
stream: &mut S,
read_buf: &mut BytesMut,
parse: P,
) -> Result<Option<M>, ConnectionError>
where
P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
{
loop {
if let Some(msg) = parse(read_buf)? {
return Ok(Some(msg));
}
// If we can't build a frame yet, try to read more data and try again.
// Make sure we've got room for at least one byte to read to ensure
// that we don't get a spurious 0 that looks like EOF.
read_buf.reserve(1);
if stream.read_buf(read_buf).await? == 0 {
if read_buf.has_remaining() {
return Err(io::Error::new(
ErrorKind::UnexpectedEof,
"EOF with unprocessed data in the buffer",
)
.into());
} else {
return Ok(None); // clean EOF
}
}
}
}
async fn flush<S: AsyncWrite + Unpin>(
stream: &mut S,
write_buf: &mut BytesMut,
) -> Result<(), io::Error> {
while write_buf.has_remaining() {
let bytes_written = stream.write(write_buf.chunk()).await?;
if bytes_written == 0 {
return Err(io::Error::new(
ErrorKind::WriteZero,
"failed to write message",
));
}
// The advanced part will be garbage collected, likely during shifting
// data left on next attempt to write to buffer when free space is not
// enough.
write_buf.advance(bytes_written);
}
write_buf.clear();
stream.flush().await
}
async fn shutdown<S: AsyncWrite + Unpin>(
stream: &mut S,
write_buf: &mut BytesMut,
) -> Result<(), io::Error> {
flush(stream, write_buf).await?;
stream.shutdown().await
}

View File

@@ -2,18 +2,24 @@
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
//! on message formats.
pub mod framed;
// Tools for calling certain async methods in sync contexts.
pub mod sync;
use byteorder::{BigEndian, ReadBytesExt};
use anyhow::{ensure, Context, Result};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use postgres_protocol::PG_EPOCH;
use serde::{Deserialize, Serialize};
use std::{
borrow::Cow,
collections::HashMap,
fmt, io, str,
fmt,
future::Future,
io::{self, Cursor},
str,
time::{Duration, SystemTime},
};
use sync::{AsyncishRead, SyncFuture};
use tokio::io::AsyncReadExt;
use tracing::{trace, warn};
pub type Oid = u32;
@@ -25,6 +31,7 @@ pub const TEXT_OID: Oid = 25;
#[derive(Debug)]
pub enum FeMessage {
StartupPacket(FeStartupPacket),
// Simple query.
Query(Bytes),
// Extended query protocol.
@@ -184,205 +191,260 @@ pub struct FeExecuteMessage {
#[derive(Debug)]
pub struct FeCloseMessage;
/// An error occured while parsing or serializing raw stream into Postgres
/// messages.
#[derive(thiserror::Error, Debug)]
pub enum ProtocolError {
/// Invalid packet was received from the client (e.g. unexpected message
/// type or broken len).
#[error("Protocol error: {0}")]
Protocol(String),
/// Failed to parse or, (unlikely), serialize a protocol message.
#[error("Message parse error: {0}")]
BadMessage(String),
/// Retry a read on EINTR
///
/// This runs the enclosed expression, and if it returns
/// Err(io::ErrorKind::Interrupted), retries it.
macro_rules! retry_read {
( $x:expr ) => {
loop {
match $x {
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
res => break res,
}
}
};
}
impl ProtocolError {
/// Proxy stream.rs uses only io::Error; provide it.
/// An error occured during connection being open.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
/// IO error during writing to or reading from the connection socket.
#[error("Socket IO error: {0}")]
Socket(std::io::Error),
/// Invalid packet was received from client
#[error("Protocol error: {0}")]
Protocol(String),
/// Failed to parse a protocol mesage
#[error("Message parse error: {0}")]
MessageParse(anyhow::Error),
}
impl From<anyhow::Error> for ConnectionError {
fn from(e: anyhow::Error) -> Self {
Self::MessageParse(e)
}
}
impl ConnectionError {
pub fn into_io_error(self) -> io::Error {
io::Error::new(io::ErrorKind::Other, self.to_string())
match self {
ConnectionError::Socket(io) => io,
other => io::Error::new(io::ErrorKind::Other, other.to_string()),
}
}
}
impl FeMessage {
/// Read and parse one message from the `buf` input buffer. If there is at
/// least one valid message, returns it, advancing `buf`; redundant copies
/// are avoided, as thanks to `bytes` crate ptrs in parsed message point
/// directly into the `buf` (processed data is garbage collected after
/// parsed message is dropped).
/// Read one message from the stream.
/// This function returns `Ok(None)` in case of EOF.
/// One way to handle this properly:
///
/// Returns None if `buf` doesn't contain enough data for a single message.
/// For efficiency, tries to reserve large enough space in `buf` for the
/// next message in this case to save the repeated calls.
/// ```
/// # use std::io;
/// # use pq_proto::FeMessage;
/// #
/// # fn process_message(msg: FeMessage) -> anyhow::Result<()> {
/// # Ok(())
/// # };
/// #
/// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> {
/// while let Some(msg) = FeMessage::read(stream)? {
/// process_message(msg)?;
/// }
///
/// Returns Error if message is malformed, the only possible ErrorKind is
/// InvalidInput.
//
// Inspired by rust-postgres Message::parse.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
// Every message contains message type byte and 4 bytes len; can't do
// much without them.
if buf.len() < 5 {
let to_read = 5 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
/// Ok(())
/// }
/// ```
#[inline(never)]
pub fn read(
stream: &mut (impl io::Read + Unpin),
) -> Result<Option<FeMessage>, ConnectionError> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
// We shouldn't advance `buf` as probably full message is not there yet,
// so can't directly use Bytes::get_u32 etc.
let tag = buf[0];
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
if len < 4 {
return Err(ProtocolError::Protocol(format!(
"invalid message length {}",
len
)));
}
/// Read one message from the stream.
/// See documentation for `Self::read`.
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
// AsyncReadExt methods of the stream.
SyncFuture::new(async move {
// Each libpq message begins with a message type byte, followed by message length
// If the client closes the connection, return None. But if the client closes the
// connection in the middle of a message, we will return an error.
let tag = match retry_read!(stream.read_u8().await) {
Ok(b) => b,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ConnectionError::Socket(e)),
};
// length field includes itself, but not message type.
let total_len = len as usize + 1;
if buf.len() < total_len {
// Don't have full message yet.
let to_read = total_len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// The message length includes itself, so it better be at least 4.
let len = retry_read!(stream.read_u32().await)
.map_err(ConnectionError::Socket)?
.checked_sub(4)
.ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?;
// got the message, advance buffer
let mut msg = buf.split_to(total_len).freeze();
msg.advance(5); // consume message type and len
let body = {
let mut buffer = vec![0u8; len as usize];
stream
.read_exact(&mut buffer)
.await
.map_err(ConnectionError::Socket)?;
Bytes::from(buffer)
};
match tag {
b'Q' => Ok(Some(FeMessage::Query(msg))),
b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(msg))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
tag => Err(ProtocolError::Protocol(format!(
"unknown message tag: {tag},'{msg:?}'"
))),
}
match tag {
b'Q' => Ok(Some(FeMessage::Query(body))),
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(body))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
tag => {
return Err(ConnectionError::Protocol(format!(
"unknown message tag: {tag},'{body:?}'"
)))
}
}
})
}
}
impl FeStartupPacket {
/// Read and parse startup message from the `buf` input buffer. It is
/// different from [`FeMessage::parse`] because startup messages don't have
/// message type byte; otherwise, its comments apply.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read(
stream: &mut (impl io::Read + Unpin),
) -> Result<Option<FeMessage>, ConnectionError> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
const CANCEL_REQUEST_CODE: u32 = 5678;
const NEGOTIATE_SSL_CODE: u32 = 5679;
const NEGOTIATE_GSS_CODE: u32 = 5680;
// need at least 4 bytes with packet len
if buf.len() < 4 {
let to_read = 4 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
SyncFuture::new(async move {
// Read length. If the connection is closed before reading anything (or before
// reading 4 bytes, to be precise), return None to indicate that the connection
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
// in the log if the client opens connection but closes it immediately.
let len = match retry_read!(stream.read_u32().await) {
Ok(len) => len as usize,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ConnectionError::Socket(e)),
};
// We shouldn't advance `buf` as probably full message is not there yet,
// so can't directly use Bytes::get_u32 etc.
let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
return Err(ProtocolError::Protocol(format!(
"invalid startup packet message length {}",
len
)));
}
if buf.len() < len {
// Don't have full message yet.
let to_read = len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// got the message, advance buffer
let mut msg = buf.split_to(len).freeze();
msg.advance(4); // consume len
let request_code = msg.get_u32();
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
if msg.remaining() != 8 {
return Err(ProtocolError::BadMessage(
"CancelRequest message is malformed, backend PID / secret key missing"
.to_owned(),
));
}
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: msg.get_i32(),
cancel_key: msg.get_i32(),
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
return Err(ProtocolError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
#[allow(clippy::manual_range_contains)]
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
return Err(ConnectionError::Protocol(format!(
"invalid message length {len}"
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// StartupMessage
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&msg)
.map_err(|_e| {
ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
})?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
.split_terminator('\0');
let request_code =
retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?;
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
// the rest of startup packet are params
let params_len = len - 8;
let mut params_bytes = vec![0u8; params_len];
stream
.read_exact(params_bytes.as_mut())
.await
.map_err(ConnectionError::Socket)?;
params.insert(name.to_owned(), value.to_owned());
// Parse params depending on request code
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
if params_len != 8 {
return Err(ConnectionError::Protocol(
"expected 8 bytes for CancelRequest params".to_string(),
));
}
let mut cursor = Cursor::new(params_bytes);
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
})
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params: StartupMessageParams { params },
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
}
};
Ok(Some(message))
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
return Err(ConnectionError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&params_bytes)
.context("StartupMessage params: invalid utf-8")?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ConnectionError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
.split_terminator('\0');
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ConnectionError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
params.insert(name.to_owned(), value.to_owned());
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params: StartupMessageParams { params },
}
}
};
Ok(Some(FeMessage::StartupPacket(message)))
})
}
}
impl FeParseMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
// FIXME: the rust-postgres driver uses a named prepared statement
// for copy_out(). We're not prepared to handle that correctly. For
// now, just ignore the statement name, assuming that the client never
@@ -390,82 +452,55 @@ impl FeParseMessage {
let _pstmt_name = read_cstr(&mut buf)?;
let query_string = read_cstr(&mut buf)?;
if buf.remaining() < 2 {
return Err(ProtocolError::BadMessage(
"Parse message is malformed, nparams missing".to_string(),
));
}
let nparams = buf.get_i16();
if nparams != 0 {
return Err(ProtocolError::BadMessage(
"query params not implemented".to_string(),
));
}
ensure!(nparams == 0, "query params not implemented");
Ok(FeMessage::Parse(FeParseMessage { query_string }))
}
}
impl FeDescribeMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let kind = buf.get_u8();
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
if kind != b'S' {
return Err(ProtocolError::BadMessage(
"only prepared statemement Describe is implemented".to_string(),
));
}
ensure!(
kind == b'S',
"only prepared statemement Describe is implemented"
);
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
}
}
impl FeExecuteMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_cstr(&mut buf)?;
if buf.remaining() < 4 {
return Err(ProtocolError::BadMessage(
"FeExecuteMessage message is malformed, maxrows missing".to_string(),
));
}
let maxrows = buf.get_i32();
if !portal_name.is_empty() {
return Err(ProtocolError::BadMessage(
"named portals not implemented".to_string(),
));
}
if maxrows != 0 {
return Err(ProtocolError::BadMessage(
"row limit in Execute message not implemented".to_string(),
));
}
ensure!(portal_name.is_empty(), "named portals not implemented");
ensure!(maxrows == 0, "row limit in Execute message not implemented");
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
}
}
impl FeBindMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_cstr(&mut buf)?;
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
if !portal_name.is_empty() {
return Err(ProtocolError::BadMessage(
"named portals not implemented".to_string(),
));
}
ensure!(portal_name.is_empty(), "named portals not implemented");
Ok(FeMessage::Bind(FeBindMessage))
}
}
impl FeCloseMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _kind = buf.get_u8();
let _pstmt_or_portal_name = read_cstr(&mut buf)?;
@@ -494,7 +529,6 @@ pub enum BeMessage<'a> {
CloseComplete,
// None means column is NULL
DataRow(&'a [Option<&'a [u8]>]),
// None errcode means internal_error will be sent.
ErrorResponse(&'a str, Option<&'a [u8; 5]>),
/// Single byte - used in response to SSLRequest/GSSENCRequest.
EncryptionResponse(bool),
@@ -525,11 +559,6 @@ impl<'a> BeMessage<'a> {
value: b"UTF8",
};
pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
name: b"integer_datetimes",
value: b"on",
};
/// Build a [`BeMessage::ParameterStatus`] holding the server version.
pub fn server_version(version: &'a str) -> Self {
Self::ParameterStatus {
@@ -608,7 +637,7 @@ impl RowDescriptor<'_> {
#[derive(Debug)]
pub struct XLogDataBody<'a> {
pub wal_start: u64,
pub wal_end: u64, // current end of WAL on the server
pub wal_end: u64,
pub timestamp: i64,
pub data: &'a [u8],
}
@@ -648,11 +677,12 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
}
/// Safe write of s into buf as cstring (String in the protocol).
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
let bytes = s.as_ref();
if bytes.contains(&0) {
return Err(ProtocolError::BadMessage(
"string contains embedded null".to_owned(),
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
}
buf.put_slice(bytes);
@@ -660,27 +690,22 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolErr
Ok(())
}
/// Read cstring from buf, advancing it.
fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
let pos = buf
.iter()
.position(|x| *x == 0)
.ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
let result = buf.split_to(pos);
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
let pos = buf.iter().position(|x| *x == 0);
let result = buf.split_to(pos.context("missing terminator")?);
buf.advance(1); // drop the null terminator
Ok(result)
}
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
impl<'a> BeMessage<'a> {
/// Serialize `message` to the given `buf`.
/// Apart from smart memory managemet, BytesMut is good here as msg len
/// precedes its body and it is handy to write it down first and then fill
/// the length. With Write we would have to either calc it manually or have
/// one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
/// Write message to the given buf.
// Unlike the reading side, we use BytesMut
// here as msg len precedes its body and it is handy to write it down first
// and then fill the length. With Write we would have to either calc it
// manually or have one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
match message {
BeMessage::AuthenticationOk => {
buf.put_u8(b'R');
@@ -725,7 +750,7 @@ impl<'a> BeMessage<'a> {
buf.put_slice(extra);
}
}
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -829,7 +854,7 @@ impl<'a> BeMessage<'a> {
write_cstr(error_msg, buf)?;
buf.put_u8(0); // terminator
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -852,7 +877,7 @@ impl<'a> BeMessage<'a> {
write_cstr(error_msg.as_bytes(), buf)?;
buf.put_u8(0); // terminator
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -907,7 +932,7 @@ impl<'a> BeMessage<'a> {
buf.put_i32(-1); /* typmod */
buf.put_i16(0); /* format code */
}
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -974,7 +999,7 @@ impl ReplicationFeedback {
// null-terminated string - key,
// uint32 - value length in bytes
// value itself
pub fn serialize(&self, buf: &mut BytesMut) {
pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> {
buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys
buf.put_slice(b"current_timeline_size\0");
buf.put_i32(8);
@@ -999,6 +1024,7 @@ impl ReplicationFeedback {
buf.put_slice(b"ps_replytime\0");
buf.put_i32(8);
buf.put_i64(timestamp);
Ok(())
}
// Deserialize ReplicationFeedback message
@@ -1066,7 +1092,7 @@ mod tests {
// because it is rounded up to microseconds during serialization.
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
let mut data = BytesMut::new();
rf.serialize(&mut data);
rf.serialize(&mut data).unwrap();
let rf_parsed = ReplicationFeedback::parse(data.freeze());
assert_eq!(rf, rf_parsed);
@@ -1081,7 +1107,7 @@ mod tests {
// because it is rounded up to microseconds during serialization.
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
let mut data = BytesMut::new();
rf.serialize(&mut data);
rf.serialize(&mut data).unwrap();
// Add an extra field to the buffer and adjust number of keys
if let Some(first) = data.first_mut() {
@@ -1123,6 +1149,15 @@ mod tests {
let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
}
// Make sure that `read` is sync/async callable
async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
let _ = FeMessage::read(&mut [].as_ref());
let _ = FeMessage::read_fut(stream).await;
let _ = FeStartupPacket::read(&mut [].as_ref());
let _ = FeStartupPacket::read_fut(stream).await;
}
}
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {

179
libs/pq_proto/src/sync.rs Normal file
View File

@@ -0,0 +1,179 @@
use pin_project_lite::pin_project;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::{io, task};
pin_project! {
/// We use this future to mark certain methods
/// as callable in both sync and async modes.
#[repr(transparent)]
pub struct SyncFuture<S, T: Future> {
#[pin]
inner: T,
_marker: PhantomData<S>,
}
}
/// This wrapper lets us synchronously wait for inner future's completion
/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**.
/// For instance, `S` may be substituted with types implementing
/// [`tokio::io::AsyncRead`], but it's not the only viable option.
impl<S, T: Future> SyncFuture<S, T> {
/// NOTE: caller should carefully pick a type for `S`,
/// because we don't want to enable [`SyncFuture::wait`] when
/// it's in fact impossible to run the future synchronously.
/// Violation of this contract will not cause UB, but
/// panics and async event loop freezes won't please you.
///
/// Example:
///
/// ```
/// # use pq_proto::sync::SyncFuture;
/// # use std::future::Future;
/// # use tokio::io::AsyncReadExt;
/// #
/// // Parse a pair of numbers from a stream
/// pub fn parse_pair<Reader>(
/// stream: &mut Reader,
/// ) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<(u32, u64)>> + '_>
/// where
/// Reader: tokio::io::AsyncRead + Unpin,
/// {
/// // If `Reader` is a `SyncProof`, this will give caller
/// // an opportunity to use `SyncFuture::wait`, because
/// // `.await` will always result in `Poll::Ready`.
/// SyncFuture::new(async move {
/// let x = stream.read_u32().await?;
/// let y = stream.read_u64().await?;
/// Ok((x, y))
/// })
/// }
/// ```
pub fn new(inner: T) -> Self {
Self {
inner,
_marker: PhantomData,
}
}
}
impl<S, T: Future> Future for SyncFuture<S, T> {
type Output = T::Output;
/// In async code, [`SyncFuture`] behaves like a regular wrapper.
#[inline(always)]
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
self.project().inner.poll(cx)
}
}
/// Postulates that we can call [`SyncFuture::wait`].
/// If implementer is also a [`Future`], it should always
/// return [`task::Poll::Ready`] from [`Future::poll`].
///
/// Each implementation should document which futures
/// specifically are being declared sync-proof.
pub trait SyncPostulate {}
impl<T: SyncPostulate> SyncPostulate for &T {}
impl<T: SyncPostulate> SyncPostulate for &mut T {}
impl<P: SyncPostulate, T: Future> SyncFuture<P, T> {
/// Synchronously wait for future completion.
pub fn wait(mut self) -> T::Output {
const RAW_WAKER: task::RawWaker = task::RawWaker::new(
std::ptr::null(),
&task::RawWakerVTable::new(
|_| RAW_WAKER,
|_| panic!("SyncFuture: failed to wake"),
|_| panic!("SyncFuture: failed to wake by ref"),
|_| { /* drop is no-op */ },
),
);
// SAFETY: We never move `self` during this call;
// furthermore, it will be dropped in the end regardless of panics
let this = unsafe { Pin::new_unchecked(&mut self) };
// SAFETY: This waker doesn't do anything apart from panicking
let waker = unsafe { task::Waker::from_raw(RAW_WAKER) };
let context = &mut task::Context::from_waker(&waker);
match this.poll(context) {
task::Poll::Ready(res) => res,
_ => panic!("SyncFuture: unexpected pending!"),
}
}
}
/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`],
/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`].
/// NOTE: you **should not** use this in async code.
#[repr(transparent)]
pub struct AsyncishRead<T: io::Read + Unpin>(pub T);
/// This lets us call [`SyncFuture<AsyncishRead<_>, _>::wait`],
/// and allows the future to await on any of the [`AsyncRead`]
/// and [`AsyncReadExt`] methods on `AsyncishRead`.
impl<T: io::Read + Unpin> SyncPostulate for AsyncishRead<T> {}
impl<T: io::Read + Unpin> tokio::io::AsyncRead for AsyncishRead<T> {
#[inline(always)]
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
task::Poll::Ready(
// `Read::read` will block, meaning we don't need a real event loop!
self.0
.read(buf.initialize_unfilled())
.map(|sz| buf.advance(sz)),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// async helper(stream: &mut impl AsyncRead) -> io::Result<u32>
fn bytes_add<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = io::Result<u32>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
SyncFuture::new(async move {
let a = stream.read_u32().await?;
let b = stream.read_u32().await?;
Ok(a + b)
})
}
#[test]
fn test_sync() {
let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat();
let res = bytes_add(&mut AsyncishRead(&mut &bytes[..]))
.wait()
.unwrap();
assert_eq!(res, 300);
}
// We need a single-threaded executor for this test
#[tokio::test(flavor = "current_thread")]
async fn test_async() {
let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap();
let write = async move {
tx.write_u32(100).await?;
tx.write_u32(200).await?;
Ok(())
};
let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap();
assert_eq!(res, 300);
}
}

View File

@@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
}
pub struct Download {
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send + Sync>>,
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send>>,
/// Extra key-value data, associated with the current remote file.
pub metadata: Option<StorageMetadata>,
}

View File

@@ -12,37 +12,42 @@ anyhow.workspace = true
bincode.workspace = true
bytes.workspace = true
heapless.workspace = true
hex = { workspace = true, features = ["serde"] }
hyper = { workspace = true, features = ["full"] }
futures = { workspace = true}
jsonwebtoken.workspace = true
nix.workspace = true
once_cell.workspace = true
pin-project-lite.workspace = true
routerify.workspace = true
serde.workspace = true
serde_json.workspace = true
signal-hook.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-rustls.workspace = true
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["json"] }
nix.workspace = true
signal-hook.workspace = true
rand.workspace = true
jsonwebtoken.workspace = true
hex = { workspace = true, features = ["serde"] }
rustls.workspace = true
rustls-split.workspace = true
git-version.workspace = true
serde_with.workspace = true
once_cell.workspace = true
strum.workspace = true
strum_macros.workspace = true
url.workspace = true
uuid = { version = "1.2", features = ["v4", "serde"] }
metrics.workspace = true
workspace_hack.workspace = true
pq_proto.workspace = true
workspace_hack.workspace = true
url.workspace = true
uuid = { version = "1.2", features = ["v4", "serde"] }
[dev-dependencies]
byteorder.workspace = true
bytes.workspace = true
criterion.workspace = true
hex-literal.workspace = true
tempfile.workspace = true
criterion.workspace = true
rustls-pemfile.workspace = true
[[bench]]
name = "benchmarks"

View File

@@ -9,28 +9,16 @@ use std::path::Path;
use anyhow::Result;
use jsonwebtoken::{
decode, encode, Algorithm, Algorithm::*, DecodingKey, EncodingKey, Header, TokenData,
Validation,
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use crate::id::TenantId;
/// Algorithms accepted during validation.
///
/// Accept all RSA-based algorithms. We pass this list to jsonwebtoken::decode,
/// which checks that the algorithm in the token is one of these.
///
/// XXX: It also fails the validation if there are any algorithms in this list that belong
/// to different family than the token's algorithm. In other words, we can *not* list any
/// non-RSA algorithms here, or the validation always fails with InvalidAlgorithm error.
const ACCEPTED_ALGORITHMS: &[Algorithm] = &[RS256, RS384, RS512];
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
/// Algorithm to use when generating a new token in [`encode_from_key_file`]
const ENCODE_ALGORITHM: Algorithm = Algorithm::RS256;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Scope {
// Provides access to all data for a specific tenant (specified in `struct Claims` below)
@@ -45,9 +33,8 @@ pub enum Scope {
SafekeeperData,
}
/// JWT payload. See docs/authentication.md for the format
#[serde_as]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
#[serde(default)]
#[serde_as(as = "Option<DisplayFromStr>")]
@@ -68,8 +55,7 @@ pub struct JwtAuth {
impl JwtAuth {
pub fn new(decoding_key: DecodingKey) -> Self {
let mut validation = Validation::default();
validation.algorithms = ACCEPTED_ALGORITHMS.into();
let mut validation = Validation::new(JWT_ALGORITHM);
// The default 'required_spec_claims' is 'exp'. But we don't want to require
// expiration.
validation.required_spec_claims = [].into();
@@ -100,113 +86,5 @@ impl std::fmt::Debug for JwtAuth {
// this function is used only for testing purposes in CLI e g generate tokens during init
pub fn encode_from_key_file(claims: &Claims, key_data: &[u8]) -> Result<String> {
let key = EncodingKey::from_rsa_pem(key_data)?;
Ok(encode(&Header::new(ENCODE_ALGORITHM), claims, &key)?)
}
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
// generated with:
//
// openssl genpkey -algorithm rsa -out storage-auth-priv.pem
// openssl pkey -in storage-auth-priv.pem -pubout -out storage-auth-pub.pem
const TEST_PUB_KEY_RSA: &[u8] = br#"
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAy6OZ+/kQXcueVJA/KTzO
v4ljxylc/Kcb0sXWuXg1GB8k3nDA1gK66LFYToH0aTnqrnqG32Vu6wrhwuvqsZA7
jQvP0ZePAbWhpEqho7EpNunDPcxZ/XDy5TQlB1P58F9I3lkJXDC+DsHYLuuzwhAv
vo2MtWRdYlVHblCVLyZtANHhUMp2HUhgjHnJh5UrLIKOl4doCBxkM3rK0wjKsNCt
M92PCR6S9rvYzldfeAYFNppBkEQrXt2CgUqZ4KaS4LXtjTRUJxljijA4HWffhxsr
euRu3ufq8kVqie7fum0rdZZSkONmce0V0LesQ4aE2jB+2Sn48h6jb4dLXGWdq8TV
wQIDAQAB
-----END PUBLIC KEY-----
"#;
const TEST_PRIV_KEY_RSA: &[u8] = br#"
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDLo5n7+RBdy55U
kD8pPM6/iWPHKVz8pxvSxda5eDUYHyTecMDWArrosVhOgfRpOequeobfZW7rCuHC
6+qxkDuNC8/Rl48BtaGkSqGjsSk26cM9zFn9cPLlNCUHU/nwX0jeWQlcML4Owdgu
67PCEC++jYy1ZF1iVUduUJUvJm0A0eFQynYdSGCMecmHlSssgo6Xh2gIHGQzesrT
CMqw0K0z3Y8JHpL2u9jOV194BgU2mkGQRCte3YKBSpngppLgte2NNFQnGWOKMDgd
Z9+HGyt65G7e5+ryRWqJ7t+6bSt1llKQ42Zx7RXQt6xDhoTaMH7ZKfjyHqNvh0tc
ZZ2rxNXBAgMBAAECggEAVz3u4Wlx3o02dsoZlSQs+xf0PEX3RXKeU+1YMbtTG9Nz
6yxpIQaoZrpbt76rJE2gwkFR+PEu1NmjoOuLb6j4KlQuI4AHz1auOoGSwFtM6e66
K4aZ4x95oEJ3vqz2fkmEIWYJwYpMUmwvnuJx76kZm0xvROMLsu4QHS2+zCVtO5Tr
hvS05IMVuZ2TdQBZw0+JaFdwXbgDjQnQGY5n9MoTWSx1a4s/FF4Eby65BbDutcpn
Vt3jQAOmO1X2kbPeWSGuPJRzyUs7Kg8qfeglBIR3ppGP3vPYAdWX+ho00bmsVkSp
Q8vjul6C3WiM+kjwDxotHSDgbl/xldAl7OqPh0bfAQKBgQDnycXuq14Vg8nZvyn9
rTnvucO8RBz5P6G+FZ+44cAS2x79+85onARmMnm+9MKYLSMo8fOvsK034NDI68XM
04QQ/vlfouvFklMTGJIurgEImTZbGCmlMYCvFyIxaEWixon8OpeI4rFe4Hmbiijh
PxhxWg221AwvBS2sco8J/ylEkQKBgQDg6Rh2QYb/j0Wou1rJPbuy3NhHofd5Rq35
4YV3f2lfVYcPrgRhwe3T9SVII7Dx8LfwzsX5TAlf48ESlI3Dzv40uOCDM+xdtBRI
r96SfSm+jup6gsXU3AsdNkrRK3HoOG9Z/TkrUp213QAIlVnvIx65l4ckFMlpnPJ0
lo1LDXZWMQKBgFArzjZ7N5OhfdO+9zszC3MLgdRAivT7OWqR+CjujIz5FYMr8Xzl
WfAvTUTrS9Nu6VZkObFvHrrRG+YjBsuN7YQjbQXTSFGSBwH34bgbn2fl9pMTjHQC
50uoaL9GHa/rlBaV/YvvPQJgCi/uXa1rMX0jdNLkDULGO8IF7cu7Yf7BAoGBAIUU
J29BkpmAst0GDs/ogTlyR18LTR0rXyHt+UUd1MGeH859TwZw80JpWWf4BmkB4DTS
hH3gKePdJY7S65ci0XNsuRupC4DeXuorde0DtkGU2tUmr9wlX0Ynq9lcdYfMbMa4
eK1TsxG69JwfkxlWlIWITWRiEFM3lJa7xlrUWmLhAoGAFpKWF/hn4zYg3seU9gai
EYHKSbhxA4mRb+F0/9IlCBPMCqFrL5yftUsYIh2XFKn8+QhO97Nmk8wJSK6TzQ5t
ZaSRmgySrUUhx4nZ/MgqWCFv8VUbLM5MBzwxPKhXkSTfR4z2vLYLJwVY7Tb4kZtp
8ismApXVGHpOCstzikV9W7k=
-----END PRIVATE KEY-----
"#;
#[test]
fn test_decode() -> Result<(), anyhow::Error> {
let expected_claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
scope: Scope::Tenant,
};
// Here are tokens containing the following payload, signed using TEST_PRIV_KEY_RSA
// using RS512, RS384 and RS256 algorithms:
//
// ```
// {
// "scope": "tenant",
// "tenant_id": "3d1f7595b468230304e0b73cecbcb081",
// "iss": "neon.controlplane",
// "exp": 1709200879,
// "iat": 1678442479
// }
// ```
//
// These were encoded with the online debugger at https://jwt.io
//
let encoded_rs512 = "eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.QmqfteDQmDGoxQ5EFkasbt35Lx0W0Nh63muQnYZvFq93DSh4ZbOG9Mc4yaiXZoiS5HgeKtFKv3mbWkDqjz3En06aY17hWwguBtAsGASX48lYeCPADYGlGAuaWnOnVRwe3iiOC7tvPFvwX_45S84X73sNUXyUiXv6nLdcDqVXudtNrGST_DnZDnjuUJX11w7sebtKqQQ8l9-iGHiXOl5yevpMCoB1OcTWcT6DfDtffoNuMHDC3fyhmEGG5oKAt1qBybqAIiyC9-UBAowRZXhdfxrzUl-I9jzKWvk85c5ulhVRwbPeP6TTTlPKwFzBNHg1i2U-1GONew5osQ3aoptwsA";
let encoded_rs384 = "eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.qqk4nkxKzOJP38c_g57_w_SfdQVmCsDT_bsLmdFj_N6LIB22gr6U6_P_5mvk3pIAsp0VCTDwPrCU908TxqjibEkwvQoJwbogHamSGHpD7eJBxGblSnA-Nr3MlEMxpFtec8QokSm6C5mH7DoBYjB2xzeOlxAmpR2GAzInKiMkU4kZ_OcqqrmVcMXY_6VnbxZWMekuw56zE1-PP_qNF1HvYOH-P08ONP8qdo5UPtBG7QBEFlCqZXJZCFihQaI4Vzil9rDuZGCm3I7xQJ8-yh1PX3BTbGo8EzqLdRyBeTpr08UTuRbp_MJDWevHpP3afvJetAItqZXIoZQrbJjcByHqKw";
let encoded_rs256 = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.dF2N9KXG8ftFKHYbd5jQtXMQqv0Ej8FISGp1b_dmqOCotXj5S1y2AWjwyB_EXHM77JXfbEoJPAPrFFBNfd8cWtkCSTvpxWoHaecGzegDFGv5ZSc5AECFV1Daahc3PI3jii9wEiGkFOiwiBNfZ5INomOAsV--XXxlqIwKbTcgSYI7lrOTfecXAbAHiMKQlQYiIBSGnytRCgafhRkyGzPAL8ismthFJ9RHfeejyskht-9GbVHURw02bUyijuHEulpf9eEY3ZiB28de6jnCdU7ftIYaUMaYWt0nZQGkzxKPSfSLZNy14DTOYLDS04DVstWQPqnCUW_ojg0wJETOOfo9Zw";
// Check that RS512, RS384 and RS256 tokens can all be validated
let auth = JwtAuth::new(DecodingKey::from_rsa_pem(TEST_PUB_KEY_RSA)?);
for encoded in [encoded_rs512, encoded_rs384, encoded_rs256] {
let claims_from_token = auth.decode(encoded)?.claims;
assert_eq!(claims_from_token, expected_claims);
}
Ok(())
}
#[test]
fn test_encode() -> Result<(), anyhow::Error> {
let claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
scope: Scope::Tenant,
};
let encoded = encode_from_key_file(&claims, TEST_PRIV_KEY_RSA)?;
// decode it back
let auth = JwtAuth::new(DecodingKey::from_rsa_pem(TEST_PUB_KEY_RSA)?);
let decoded = auth.decode(&encoded)?;
assert_eq!(decoded.claims, claims);
Ok(())
}
Ok(encode(&Header::new(JWT_ALGORITHM), claims, &key)?)
}

View File

@@ -11,7 +11,7 @@ where
P: AsRef<Path>,
{
fn is_empty_dir(&self) -> io::Result<bool> {
Ok(fs::read_dir(self)?.next().is_none())
Ok(fs::read_dir(self)?.into_iter().next().is_none())
}
}

View File

@@ -3,14 +3,14 @@ use crate::http::error;
use anyhow::{anyhow, Context};
use hyper::header::{HeaderName, AUTHORIZATION};
use hyper::http::HeaderValue;
use hyper::Method;
use hyper::{header::CONTENT_TYPE, Body, Request, Response, Server};
use hyper::{Method, StatusCode};
use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
use once_cell::sync::Lazy;
use routerify::ext::RequestExt;
use routerify::{Middleware, RequestInfo, Router, RouterBuilder, RouterService};
use tokio::task::JoinError;
use tracing::{self, debug, info, info_span, warn, Instrument};
use tracing;
use std::future::Future;
use std::net::TcpListener;
@@ -32,77 +32,31 @@ static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HE
#[derive(Debug, Default, Clone)]
struct RequestId(String);
/// Adds a tracing info_span! instrumentation around the handler events,
/// logs the request start and end events for non-GET requests and non-200 responses.
///
/// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
/// in this type will get request info logged in the wrapping span, including the unique request ID.
///
/// There could be other ways to implement similar functionality:
///
/// * procmacros placed on top of all handler methods
/// With all the drawbacks of procmacros, brings no difference implementation-wise,
/// and little code reduction compared to the existing approach.
///
/// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
/// implemented for [`RouterBuilder`].
/// Could be simpler, but we don't want to depend on [`routerify`] more, targeting to use other library later.
///
/// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
/// later, in a post-response middleware.
/// Due to suspendable nature of the futures, would give contradictive results which is exactly the opposite of what `tracing-futures`
/// tries to achive with its `.instrument` used in the current approach.
///
/// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
pub struct RequestSpan<E, R, H>(pub H)
where
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
R: Future<Output = Result<Response<Body>, E>> + Send + 'static,
H: Fn(Request<Body>) -> R + Send + Sync + 'static;
async fn logger(res: Response<Body>, info: RequestInfo) -> Result<Response<Body>, ApiError> {
let request_id = info.context::<RequestId>().unwrap_or_default().0;
impl<E, R, H> RequestSpan<E, R, H>
where
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
R: Future<Output = Result<Response<Body>, E>> + Send + 'static,
H: Fn(Request<Body>) -> R + Send + Sync + 'static,
{
/// Creates a tracing span around inner request handler and executes the request handler in the contex of that span.
/// Use as `|r| RequestSpan(my_handler).handle(r)` instead of `my_handler` as the request handler to get the span enabled.
pub async fn handle(self, request: Request<Body>) -> Result<Response<Body>, E> {
let request_id = request.context::<RequestId>().unwrap_or_default().0;
let method = request.method();
let path = request.uri().path();
let request_span = info_span!("request", %method, %path, %request_id);
// cannot factor out the Level to avoid the repetition
// because tracing can only work with const Level
// which is not the case here
let log_quietly = method == Method::GET;
async move {
if log_quietly {
debug!("Handling request");
} else {
info!("Handling request");
}
// Note that we reuse `error::handler` here and not returning and error at all,
// yet cannot use `!` directly in the method signature due to `routerify::RouterBuilder` limitation.
// Usage of the error handler also means that we expect only the `ApiError` errors to be raised in this call.
//
// Panics are not handled separately, there's a `tracing_panic_hook` from another module to do that globally.
match (self.0)(request).await {
Ok(response) => {
let response_status = response.status();
if log_quietly && response_status.is_success() {
debug!("Request handled, status: {response_status}");
} else {
info!("Request handled, status: {response_status}");
}
Ok(response)
}
Err(e) => Ok(error::handler(e.into()).await),
}
}
.instrument(request_span)
.await
if info.method() == Method::GET && res.status() == StatusCode::OK {
tracing::debug!(
"{} {} {} {}",
info.method(),
info.uri().path(),
request_id,
res.status()
);
} else {
tracing::info!(
"{} {} {} {}",
info.method(),
info.uri().path(),
request_id,
res.status()
);
}
Ok(res)
}
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
@@ -142,6 +96,12 @@ pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'stati
request_id.to_string()
}
};
if req.method() == Method::GET {
tracing::debug!("{} {} {}", req.method(), req.uri().path(), request_id);
} else {
tracing::info!("{} {} {}", req.method(), req.uri().path(), request_id);
}
req.set_context(RequestId(request_id));
Ok(req)
@@ -165,12 +125,11 @@ async fn add_request_id_header_to_response(
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
Router::builder()
.middleware(add_request_id_middleware())
.middleware(Middleware::post_with_info(logger))
.middleware(Middleware::post_with_info(
add_request_id_header_to_response,
))
.get("/metrics", |r| {
RequestSpan(prometheus_metrics_handler).handle(r)
})
.get("/metrics", prometheus_metrics_handler)
.err_handler(error::handler)
}
@@ -180,43 +139,40 @@ pub fn attach_openapi_ui(
spec_mount_path: &'static str,
ui_mount_path: &'static str,
) -> RouterBuilder<hyper::Body, ApiError> {
router_builder
.get(spec_mount_path, move |r| {
RequestSpan(move |_| async move { Ok(Response::builder().body(Body::from(spec)).unwrap()) })
.handle(r)
})
.get(ui_mount_path, move |r| RequestSpan( move |_| async move {
Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
window.onload = function() {{
const ui = SwaggerUIBundle({{
"dom_id": "\#swagger-ui",
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
url: "{}",
}})
window.ui = ui;
}};
</script>
</body>
</html>
"#, spec_mount_path))).unwrap())
}).handle(r))
router_builder.get(spec_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(spec)).unwrap())
}).get(ui_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
window.onload = function() {{
const ui = SwaggerUIBundle({{
"dom_id": "\#swagger-ui",
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
url: "{}",
}})
window.ui = ui;
}};
</script>
</body>
</html>
"#, spec_mount_path))).unwrap())
})
}
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
@@ -278,7 +234,7 @@ where
async move {
let headers = response.headers_mut();
if headers.contains_key(&name) {
warn!(
tracing::warn!(
"{} response already contains header {:?}",
request_info.uri(),
&name,
@@ -318,7 +274,7 @@ pub fn serve_thread_main<S>(
where
S: Future<Output = ()> + Send + Sync,
{
info!("Starting an HTTP endpoint at {}", listener.local_addr()?);
tracing::info!("Starting an HTTP endpoint at {}", listener.local_addr()?);
// Create a Service from the router above to handle incoming requests.
let service = RouterService::new(router_builder.build().map_err(|err| anyhow!(err))?).unwrap();

View File

@@ -13,6 +13,8 @@ pub mod simple_rcu;
pub mod vec_map;
pub mod bin_ser;
pub mod postgres_backend;
pub mod postgres_backend_async;
// helper functions for creating and fsyncing
pub mod crashsafe;
@@ -25,6 +27,9 @@ pub mod id;
// http endpoint utils
pub mod http;
// socket splitting utils
pub mod sock_split;
// common log initialisation routine
pub mod logging;
@@ -49,8 +54,6 @@ pub mod fs_ext;
pub mod history_buffer;
pub mod measured_stream;
/// use with fail::cfg("$name", "return(2000)")
#[macro_export]
macro_rules! failpoint_sleep_millis_async {

View File

@@ -6,7 +6,6 @@ use std::ops::{Add, AddAssign};
use std::path::Path;
use std::str::FromStr;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::info;
use crate::seqwait::MonotonicCounter;
@@ -240,7 +239,6 @@ impl MonotonicCounter<Lsn> for RecordLsn {
let new_prev = self.last;
self.last = lsn;
self.prev = new_prev;
info!("advanced record lsn to {}/{}", self.last, self.prev);
}
fn cnt_value(&self) -> Lsn {
self.last

View File

@@ -1,77 +0,0 @@
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::{io, task};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pin_project! {
/// This stream tracks all writes and calls user provided
/// callback when the underlying stream is flushed.
pub struct MeasuredStream<S, R, W> {
#[pin]
stream: S,
write_count: usize,
inc_read_count: R,
inc_write_count: W,
}
}
impl<S, R, W> MeasuredStream<S, R, W> {
pub fn new(stream: S, inc_read_count: R, inc_write_count: W) -> Self {
Self {
stream,
write_count: 0,
inc_read_count,
inc_write_count,
}
}
}
impl<S: AsyncRead + Unpin, R: FnMut(usize), W> AsyncRead for MeasuredStream<S, R, W> {
fn poll_read(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.project();
let filled = buf.filled().len();
this.stream.poll_read(context, buf).map_ok(|()| {
let cnt = buf.filled().len() - filled;
// Increment the read count.
(this.inc_read_count)(cnt);
})
}
}
impl<S: AsyncWrite + Unpin, R, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, R, W> {
fn poll_write(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
let this = self.project();
this.stream.poll_write(context, buf).map_ok(|cnt| {
// Increment the write count.
*this.write_count += cnt;
cnt
})
}
fn poll_flush(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.project();
this.stream.poll_flush(context).map_ok(|()| {
// Call the user provided callback and reset the write count.
(this.inc_write_count)(*this.write_count);
*this.write_count = 0;
})
}
fn poll_shutdown(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
self.project().stream.poll_shutdown(context)
}
}

View File

@@ -0,0 +1,485 @@
//! Server-side synchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::postgres_backend_async::{log_query_error, short_error, QueryError};
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
use anyhow::Context;
use bytes::{Bytes, BytesMut};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::io::{self, Write};
use std::net::{Shutdown, SocketAddr, TcpStream};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tracing::*;
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
fn is_shutdown_requested(&self) -> bool {
false
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
Initialization,
Encrypted,
Authentication,
Established,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
}
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
})
}
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Bidirectional(BidiStream),
WriteOnly(WriteStream),
}
impl Stream {
fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how),
Self::WriteOnly(write_stream) => write_stream.shutdown(how),
}
}
}
impl io::Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.write(buf),
Self::WriteOnly(write_stream) => write_stream.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.flush(),
Self::WriteOnly(write_stream) => write_stream.flush(),
}
}
}
pub struct PostgresBackend {
stream: Option<Stream>,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
buf_out: BytesMut,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
// Helper function for socket read loops
pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool {
for cause in error.chain() {
if let Some(io_error) = cause.downcast_ref::<io::Error>() {
if io_error.kind() == std::io::ErrorKind::WouldBlock {
return true;
}
}
}
false
}
// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
set_read_timeout: bool,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
if set_read_timeout {
socket
.set_read_timeout(Some(Duration::from_secs(5)))
.unwrap();
}
Ok(Self {
stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))),
buf_out: BytesMut::with_capacity(10 * 1024),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn into_stream(self) -> Stream {
self.stream.unwrap()
}
/// Get direct reference (into the Option) to the read stream.
fn get_stream_in(&mut self) -> anyhow::Result<&mut BidiStream> {
match &mut self.stream {
Some(Stream::Bidirectional(stream)) => Ok(stream),
_ => anyhow::bail!("reader taken"),
}
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
pub fn take_stream_in(&mut self) -> Option<ReadStream> {
let stream = self.stream.take();
match stream {
Some(Stream::Bidirectional(bidi_stream)) => {
let (read, write) = bidi_stream.split();
self.stream = Some(Stream::WriteOnly(write));
Some(read)
}
stream => {
self.stream = stream;
None
}
}
}
/// Read full message or return None if connection is closed.
pub fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
let (state, stream) = (self.state, self.get_stream_in()?);
use ProtoState::*;
match state {
Initialization | Encrypted => FeStartupPacket::read(stream),
Authentication | Established => FeMessage::read(stream),
}
.map_err(QueryError::from)
}
/// Write message into internal output buffer.
pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Flush output buffer into the socket.
pub fn flush(&mut self) -> io::Result<&mut Self> {
let stream = self.stream.as_mut().unwrap();
stream.write_all(&self.buf_out)?;
self.buf_out.clear();
Ok(self)
}
/// Write message into internal buffer and flush it.
pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush()
}
// Wrapper for run_message_loop() that shuts down socket when we are done
pub fn run(mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
let ret = self.run_message_loop(handler);
if let Some(stream) = self.stream.as_mut() {
let _ = stream.shutdown(Shutdown::Both);
}
ret
}
fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
trace!("postgres backend to {:?} started", self.peer_addr);
let mut unnamed_query_string = Bytes::new();
while !handler.is_shutdown_requested() {
match self.read_message() {
Ok(message) => {
if let Some(msg) = message {
trace!("got message {msg:?}");
match self.process_message(handler, msg, &mut unnamed_query_string)? {
ProcessMsgResult::Continue => continue,
ProcessMsgResult::Break => break,
}
} else {
break;
}
}
Err(e) => {
if let QueryError::Other(e) = &e {
if is_socket_read_timed_out(e) {
continue;
}
}
return Err(e);
}
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
pub fn start_tls(&mut self) -> anyhow::Result<()> {
match self.stream.take() {
Some(Stream::Bidirectional(bidi_stream)) => {
let conn = rustls::ServerConnection::new(self.tls_config.clone().unwrap())?;
self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(conn)?));
Ok(())
}
stream => {
self.stream = stream;
anyhow::bail!("can't start TLs without bidi stream");
}
}
}
fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
if self.state < ProtoState::Established
&& !matches!(
msg,
FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_)
)
{
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
}
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {m:?}");
match m {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls()?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse(
"must connect with TLS",
None,
))?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
return Ok(ProcessMsgResult::Break);
}
}
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
match self.auth_type {
AuthType::Trust => unreachable!(),
AuthType::NeonJWT => {
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
}
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string) {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string) {
log_query_error(query_string, &e);
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {msg:?}"
)));
}
}
Ok(ProcessMsgResult::Continue)
}
}

View File

@@ -0,0 +1,634 @@
//! Server-side asynchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::postgres_backend::AuthType;
use anyhow::Context;
use bytes::{Buf, Bytes, BytesMut};
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use std::{future::Future, task::ready};
use tracing::{debug, error, info, trace};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio_rustls::TlsAcceptor;
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
/// An error, occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
/// The connection was lost while processing the query.
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl From<io::Error> for QueryError {
fn from(e: io::Error) -> Self {
Self::Disconnected(ConnectionError::Socket(e))
}
}
impl QueryError {
pub fn pg_error_code(&self) -> &'static [u8; 5] {
match self {
Self::Disconnected(_) => b"08006", // connection failure
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
}
}
}
#[async_trait::async_trait]
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
Initialization,
Encrypted,
Authentication,
Established,
Closed,
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Unencrypted(BufReader<tokio::net::TcpStream>),
Tls(Box<tokio_rustls::server::TlsStream<BufReader<tokio::net::TcpStream>>>),
Broken,
}
impl AsyncWrite for Stream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Broken => unreachable!(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
Self::Broken => unreachable!(),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Broken => unreachable!(),
}
}
}
impl AsyncRead for Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Broken => unreachable!(),
}
}
}
pub struct PostgresBackend {
stream: Stream,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
// The data between 0 and "current position" as tracked by the bytes::Buf
// implementation of BytesMut, have already been written.
buf_out: BytesMut,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
Ok(Self {
stream: Stream::Unencrypted(BufReader::new(socket)),
buf_out: BytesMut::with_capacity(10 * 1024),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
/// Read full message or return None if connection is closed.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
use ProtoState::*;
match self.state {
Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await,
Authentication | Established => FeMessage::read_fut(&mut self.stream).await,
Closed => Ok(None),
}
.map_err(QueryError::from)
}
/// Flush output buffer into the socket.
pub async fn flush(&mut self) -> io::Result<()> {
while self.buf_out.has_remaining() {
let bytes_written = self.stream.write(self.buf_out.chunk()).await?;
self.buf_out.advance(bytes_written);
}
self.buf_out.clear();
Ok(())
}
/// Write message into internal output buffer.
pub fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Returns an AsyncWrite implementation that wraps all the data written
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter {
CopyDataWriter { pgb: self }
}
/// A polling function that tries to write all the data from 'buf_out' to the
/// underlying stream.
fn poll_write_buf(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
while self.buf_out.has_remaining() {
match ready!(Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk())) {
Ok(bytes_written) => self.buf_out.advance(bytes_written),
Err(err) => return Poll::Ready(Err(err)),
}
}
Poll::Ready(Ok(()))
}
fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
let ret = self.run_message_loop(handler, shutdown_watcher).await;
let _ = self.stream.shutdown();
ret
}
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
trace!("postgres backend to {:?} started", self.peer_addr);
tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received during handshake");
return Ok(())
},
result = async {
while self.state < ProtoState::Established {
if let Some(msg) = self.read_message().await? {
trace!("got message {msg:?} during handshake");
match self.process_handshake_message(handler, msg).await? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => {
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
return Ok(());
}
}
} else {
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
return Ok(());
}
}
Ok::<(), QueryError>(())
} => {
// Handshake complete.
result?;
}
);
// Authentication completed
let mut query_string = Bytes::new();
while let Some(msg) = tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received in run_message_loop");
Ok(None)
},
msg = self.read_message() => { msg },
)? {
trace!("got message {:?}", msg);
let result = self.process_message(handler, msg, &mut query_string).await;
self.flush().await?;
match result? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => break,
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
if let Stream::Unencrypted(plain_stream) =
std::mem::replace(&mut self.stream, Stream::Broken)
{
let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap());
let tls_stream = acceptor.accept(plain_stream).await?;
self.stream = Stream::Tls(Box::new(tls_stream));
return Ok(());
};
anyhow::bail!("TLS already started");
}
async fn process_handshake_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
) -> Result<ProcessMsgResult, QueryError> {
assert!(self.state < ProtoState::Established);
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {m:?}");
match m {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse(
"must connect with TLS",
None,
))?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)?
// The async python driver requires a valid server_version
.write_message(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
self.state = ProtoState::Closed;
return Ok(ProcessMsgResult::Break);
}
}
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
match self.auth_type {
AuthType::Trust => unreachable!(),
AuthType::NeonJWT => {
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
}
}
self.write_message(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
_ => {
self.state = ProtoState::Closed;
return Ok(ProcessMsgResult::Break);
}
}
Ok(ProcessMsgResult::Continue)
}
async fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
assert!(self.state == ProtoState::Established);
match msg {
FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => {
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
}
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {:?}",
msg
)));
}
}
Ok(ProcessMsgResult::Continue)
}
}
///
/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
/// messages.
///
pub struct CopyDataWriter<'a> {
pgb: &'a mut PostgresBackend,
}
impl<'a> AsyncWrite for CopyDataWriter<'a> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
// It's not strictly required to flush between each message, but makes it easier
// to view in wireshark, and usually the messages that the callers write are
// decently-sized anyway.
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
// CopyData
// XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32.
this.pgb.write_message(&BeMessage::CopyData(buf))?;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
this.pgb.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
this.pgb.poll_flush(cx)
}
}
pub fn short_error(e: &QueryError) -> String {
match e {
QueryError::Disconnected(connection_error) => connection_error.to_string(),
QueryError::Other(e) => format!("{e:#}"),
}
}
pub(super) fn log_query_error(query: &str, e: &QueryError) {
match e {
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
if is_expected_io_error(io_error) {
info!("query handler for '{query}' failed with expected io error: {io_error}");
} else {
error!("query handler for '{query}' failed with io error: {io_error}");
}
}
QueryError::Disconnected(other_connection_error) => {
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
}
QueryError::Other(e) => {
error!("query handler for '{query}' failed: {e:?}");
}
}
}

View File

@@ -0,0 +1,206 @@
use std::{
io::{self, BufReader, Write},
net::{Shutdown, TcpStream},
sync::Arc,
};
use rustls::Connection;
/// Wrapper supporting reads of a shared TcpStream.
pub struct ArcTcpRead(Arc<TcpStream>);
impl io::Read for ArcTcpRead {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
(&*self.0).read(buf)
}
}
impl std::ops::Deref for ArcTcpRead {
type Target = TcpStream;
fn deref(&self) -> &Self::Target {
self.0.deref()
}
}
/// Wrapper around a TCP Stream supporting buffered reads.
pub struct BufStream(BufReader<ArcTcpRead>);
impl io::Read for BufStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl io::Write for BufStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.get_ref().write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.get_ref().flush()
}
}
impl BufStream {
/// Unwrap into the internal BufReader.
fn into_reader(self) -> BufReader<ArcTcpRead> {
self.0
}
/// Returns a reference to the underlying TcpStream.
fn get_ref(&self) -> &TcpStream {
&self.0.get_ref().0
}
}
pub enum ReadStream {
Tcp(BufReader<ArcTcpRead>),
Tls(rustls_split::ReadHalf),
}
impl io::Read for ReadStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(reader) => reader.read(buf),
Self::Tls(read_half) => read_half.read(buf),
}
}
}
impl ReadStream {
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(write_half) => write_half.shutdown(how),
}
}
}
pub enum WriteStream {
Tcp(Arc<TcpStream>),
Tls(rustls_split::WriteHalf),
}
impl WriteStream {
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.shutdown(how),
Self::Tls(write_half) => write_half.shutdown(how),
}
}
}
impl io::Write for WriteStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.as_ref().write(buf),
Self::Tls(write_half) => write_half.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.as_ref().flush(),
Self::Tls(write_half) => write_half.flush(),
}
}
}
type TlsStream<T> = rustls::StreamOwned<rustls::ServerConnection, T>;
pub enum BidiStream {
Tcp(BufStream),
/// This variant is boxed, because [`rustls::ServerConnection`] is quite larger than [`BufStream`].
Tls(Box<TlsStream<BufStream>>),
}
impl BidiStream {
pub fn from_tcp(stream: TcpStream) -> Self {
Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream)))))
}
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(tls_boxed) => {
if how == Shutdown::Read {
tls_boxed.sock.get_ref().shutdown(how)
} else {
tls_boxed.conn.send_close_notify();
let res = tls_boxed.flush();
tls_boxed.sock.get_ref().shutdown(how)?;
res
}
}
}
}
/// Split the bi-directional stream into two owned read and write halves.
pub fn split(self) -> (ReadStream, WriteStream) {
match self {
Self::Tcp(stream) => {
let reader = stream.into_reader();
let stream: Arc<TcpStream> = reader.get_ref().0.clone();
(ReadStream::Tcp(reader), WriteStream::Tcp(stream))
}
Self::Tls(tls_boxed) => {
let reader = tls_boxed.sock.into_reader();
let buffer_data = reader.buffer().to_owned();
let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192);
let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192);
// TODO would be nice to avoid the Arc here
let socket = Arc::try_unwrap(reader.into_inner().0).unwrap();
let (read_half, write_half) = rustls_split::split(
socket,
Connection::Server(tls_boxed.conn),
read_buf_cfg,
write_buf_cfg,
);
(ReadStream::Tls(read_half), WriteStream::Tls(write_half))
}
}
}
pub fn start_tls(self, mut conn: rustls::ServerConnection) -> io::Result<Self> {
match self {
Self::Tcp(mut stream) => {
conn.complete_io(&mut stream)?;
assert!(!conn.is_handshaking());
Ok(Self::Tls(Box::new(TlsStream::new(conn, stream))))
}
Self::Tls { .. } => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"TLS is already started on this stream",
)),
}
}
}
impl io::Read for BidiStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
Self::Tls(tls_boxed) => tls_boxed.read(buf),
}
}
}
impl io::Write for BidiStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.write(buf),
Self::Tls(tls_boxed) => tls_boxed.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.flush(),
Self::Tls(tls_boxed) => tls_boxed.flush(),
}
}
}

View File

@@ -0,0 +1,238 @@
use std::{
collections::HashMap,
io::{Cursor, Read, Write},
net::{TcpListener, TcpStream},
sync::Arc,
};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use once_cell::sync::Lazy;
use utils::{
postgres_backend::{AuthType, Handler, PostgresBackend},
postgres_backend_async::QueryError,
};
fn make_tcp_pair() -> (TcpStream, TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(addr).unwrap();
let (server_stream, _) = listener.accept().unwrap();
(server_stream, client_stream)
}
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("key.pem"));
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
});
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
});
#[test]
// [false-positive](https://github.com/rust-lang/rust-clippy/issues/9274),
// we resize the vector so doing some modifications after all
#[allow(clippy::read_zero_byte_vec)]
fn ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
const QUERY: &str = "hello world";
let client_jh = std::thread::spawn(move || {
// SSLRequest
client_sock.write_u32::<BigEndian>(8).unwrap();
client_sock.write_u32::<BigEndian>(80877103).unwrap();
let ssl_response = client_sock.read_u8().unwrap();
assert_eq!(b'S', ssl_response);
let cfg = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&CERT).unwrap();
store
})
.with_no_client_auth();
let client_config = Arc::new(cfg);
let dns_name = "localhost".try_into().unwrap();
let mut conn = rustls::ClientConnection::new(client_config, dns_name).unwrap();
conn.complete_io(&mut client_sock).unwrap();
assert!(!conn.is_handshaking());
let mut stream = rustls::Stream::new(&mut conn, &mut client_sock);
// StartupMessage
stream.write_u32::<BigEndian>(9).unwrap();
stream.write_u32::<BigEndian>(196608).unwrap();
stream.write_u8(0).unwrap();
stream.flush().unwrap();
// wait for ReadyForQuery
let mut msg_buf = Vec::new();
loop {
let msg = stream.read_u8().unwrap();
let size = stream.read_u32::<BigEndian>().unwrap() - 4;
msg_buf.resize(size as usize, 0);
stream.read_exact(&mut msg_buf).unwrap();
if msg == b'Z' {
// ReadyForQuery
break;
}
}
// Query
stream.write_u8(b'Q').unwrap();
stream
.write_u32::<BigEndian>(4u32 + QUERY.len() as u32)
.unwrap();
stream.write_all(QUERY.as_ref()).unwrap();
stream.flush().unwrap();
// ReadyForQuery
let msg = stream.read_u8().unwrap();
assert_eq!(msg, b'Z');
});
struct TestHandler {
got_query: bool,
}
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError> {
self.got_query = query_string == QUERY;
Ok(())
}
}
let mut handler = TestHandler { got_query: false };
let cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(cfg));
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
pgb.run(&mut handler).unwrap();
assert!(handler.got_query);
client_jh.join().unwrap();
// TODO consider shutdown behavior
}
#[test]
fn no_ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
let client_jh = std::thread::spawn(move || {
let mut buf = BytesMut::new();
// SSLRequest
buf.put_u32(8);
buf.put_u32(80877103);
client_sock.write_all(&buf).unwrap();
buf.clear();
let ssl_response = client_sock.read_u8().unwrap();
assert_eq!(b'N', ssl_response);
});
struct TestHandler;
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
panic!()
}
}
let mut handler = TestHandler;
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, None, true).unwrap();
pgb.run(&mut handler).unwrap();
client_jh.join().unwrap();
}
#[test]
fn server_forces_ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
let client_jh = std::thread::spawn(move || {
// StartupMessage
client_sock.write_u32::<BigEndian>(9).unwrap();
client_sock.write_u32::<BigEndian>(196608).unwrap();
client_sock.write_u8(0).unwrap();
client_sock.flush().unwrap();
// ErrorResponse
assert_eq!(client_sock.read_u8().unwrap(), b'E');
let len = client_sock.read_u32::<BigEndian>().unwrap() - 4;
let mut body = vec![0; len as usize];
client_sock.read_exact(&mut body).unwrap();
let mut body = Bytes::from(body);
let mut errors = HashMap::new();
loop {
let field_type = body.get_u8();
if field_type == 0u8 {
break;
}
let end_idx = body.iter().position(|&b| b == 0u8).unwrap();
let mut value = body.split_to(end_idx + 1);
assert_eq!(value[end_idx], 0u8);
value.truncate(end_idx);
let old = errors.insert(field_type, value);
assert!(old.is_none());
}
assert!(!body.has_remaining());
assert_eq!("must connect with TLS", errors.get(&b'M').unwrap());
// TODO read failure
});
struct TestHandler;
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
panic!()
}
}
let mut handler = TestHandler;
let cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(cfg));
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
let res = pgb.run(&mut handler).unwrap_err();
assert_eq!("client did not connect with TLS", format!("{}", res));
client_jh.join().unwrap();
// TODO consider shutdown behavior
}

4
libs/walproposer/.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
*.a
*.o
*.tmp
pgdata

View File

@@ -0,0 +1,39 @@
[package]
name = "walproposer"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
atty.workspace = true
rand.workspace = true
regex.workspace = true
bytes.workspace = true
byteorder.workspace = true
anyhow.workspace = true
crc32c.workspace = true
hex.workspace = true
once_cell.workspace = true
log.workspace = true
libc.workspace = true
memoffset.workspace = true
thiserror.workspace = true
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["json"] }
serde.workspace = true
scopeguard.workspace = true
utils.workspace = true
safekeeper.workspace = true
postgres_ffi.workspace = true
hyper.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
env_logger.workspace = true
postgres.workspace = true
[build-dependencies]
anyhow.workspace = true
bindgen.workspace = true
cbindgen = "0.24.0"

View File

@@ -0,0 +1,16 @@
# walproposer Rust module
## Rust -> C
We compile walproposer as a static library and generate Rust bindings for it using `bindgen`.
Entrypoint header file is `bindgen_deps.h`.
## C -> Rust
We use `cbindgen` to generate C bindings for the Rust code. They are stored in `rust_bindings.h`.
## How to run the tests
```
export RUSTFLAGS="-C default-linker-libraries"
```

View File

@@ -0,0 +1,30 @@
/*
* This header file is the input to bindgen. It includes all the
* PostgreSQL headers that we need to auto-generate Rust structs
* from. If you need to expose a new struct to Rust code, add the
* header here, and whitelist the struct in the build.rs file.
*/
#include "c.h"
#include "walproposer.h"
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
// Calc a sum of two numbers. Used to test Rust->C function calls.
int TestFunc(int a, int b);
// Run a client for simple simlib test.
void RunClientC(uint32_t serverId);
void WalProposerRust();
void WalProposerCleanup();
extern bool debug_enabled;
// Initialize global variables before calling any Postgres C code.
void MyContextInit();
XLogRecPtr MyInsertRecord();

137
libs/walproposer/build.rs Normal file
View File

@@ -0,0 +1,137 @@
use std::{env, path::PathBuf, process::Command};
use anyhow::{anyhow, Context};
use bindgen::CargoCallbacks;
extern crate bindgen;
fn main() -> anyhow::Result<()> {
let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
cbindgen::Builder::new()
.with_crate(crate_dir)
.with_language(cbindgen::Language::C)
.generate()
.expect("Unable to generate bindings")
.write_to_file("rust_bindings.h");
// Tell cargo to invalidate the built crate whenever the wrapper changes
println!("cargo:rerun-if-changed=bindgen_deps.h,test.c,../../pgxn/neon/walproposer.c,build.sh");
println!("cargo:rustc-link-arg=-Wl,--start-group");
println!("cargo:rustc-link-arg=-lsim");
println!("cargo:rustc-link-arg=-lpgport_srv");
println!("cargo:rustc-link-arg=-lpostgres");
println!("cargo:rustc-link-arg=-lpgcommon_srv");
println!("cargo:rustc-link-arg=-lssl");
println!("cargo:rustc-link-arg=-lcrypto");
println!("cargo:rustc-link-arg=-lz");
println!("cargo:rustc-link-arg=-lpthread");
println!("cargo:rustc-link-arg=-lrt");
println!("cargo:rustc-link-arg=-ldl");
println!("cargo:rustc-link-arg=-lm");
println!("cargo:rustc-link-arg=-lwalproposer");
println!("cargo:rustc-link-arg=-Wl,--end-group");
println!("cargo:rustc-link-search=/home/admin/simulator/libs/walproposer");
// disable fPIE
println!("cargo:rustc-link-arg=-no-pie");
// print output of build.sh
let output = std::process::Command::new("./build.sh")
.output()
.expect("could not spawn `clang`");
println!("stdout: {}", String::from_utf8(output.stdout).unwrap());
println!("stderr: {}", String::from_utf8(output.stderr).unwrap());
if !output.status.success() {
// Panic if the command was not successful.
panic!("could not compile object file");
}
// // Finding the location of C headers for the Postgres server:
// // - if POSTGRES_INSTALL_DIR is set look into it, otherwise look into `<project_root>/pg_install`
// // - if there's a `bin/pg_config` file use it for getting include server, otherwise use `<project_root>/pg_install/{PG_MAJORVERSION}/include/postgresql/server`
let pg_install_dir = if let Some(postgres_install_dir) = env::var_os("POSTGRES_INSTALL_DIR") {
postgres_install_dir.into()
} else {
PathBuf::from("pg_install")
};
let pg_version = "v15";
let mut pg_install_dir_versioned = pg_install_dir.join(pg_version);
if pg_install_dir_versioned.is_relative() {
let cwd = env::current_dir().context("Failed to get current_dir")?;
pg_install_dir_versioned = cwd.join("..").join("..").join(pg_install_dir_versioned);
}
let pg_config_bin = pg_install_dir_versioned
.join(pg_version)
.join("bin")
.join("pg_config");
let inc_server_path: String = if pg_config_bin.exists() {
let output = Command::new(pg_config_bin)
.arg("--includedir-server")
.output()
.context("failed to execute `pg_config --includedir-server`")?;
if !output.status.success() {
panic!("`pg_config --includedir-server` failed")
}
String::from_utf8(output.stdout)
.context("pg_config output is not UTF-8")?
.trim_end()
.into()
} else {
let server_path = pg_install_dir_versioned
.join("include")
.join("postgresql")
.join("server")
.into_os_string();
server_path
.into_string()
.map_err(|s| anyhow!("Bad postgres server path {s:?}"))?
};
let inc_pgxn_path = "/home/admin/simulator/pgxn/neon";
// The bindgen::Builder is the main entry point
// to bindgen, and lets you build up options for
// the resulting bindings.
let bindings = bindgen::Builder::default()
// The input header we would like to generate
// bindings for.
.header("bindgen_deps.h")
// Tell cargo to invalidate the built crate whenever any of the
// included header files changed.
.parse_callbacks(Box::new(CargoCallbacks))
.allowlist_function("TestFunc")
.allowlist_function("RunClientC")
.allowlist_function("WalProposerRust")
.allowlist_function("MyContextInit")
.allowlist_function("WalProposerCleanup")
.allowlist_function("MyInsertRecord")
.allowlist_var("wal_acceptors_list")
.allowlist_var("wal_acceptor_reconnect_timeout")
.allowlist_var("wal_acceptor_connection_timeout")
.allowlist_var("am_wal_proposer")
.allowlist_var("neon_timeline_walproposer")
.allowlist_var("neon_tenant_walproposer")
.allowlist_var("syncSafekeepers")
.allowlist_var("sim_redo_start_lsn")
.allowlist_var("debug_enabled")
.clang_arg(format!("-I{inc_server_path}"))
.clang_arg(format!("-I{inc_pgxn_path}"))
.clang_arg(format!("-DSIMLIB"))
// Finish the builder and generate the bindings.
.generate()
// Unwrap the Result and panic on failure.
.expect("Unable to generate bindings");
// Write the bindings to the $OUT_DIR/bindings.rs file.
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("bindings.rs");
bindings
.write_to_file(out_path)
.expect("Couldn't write bindings!");
Ok(())
}

21
libs/walproposer/build.sh Executable file
View File

@@ -0,0 +1,21 @@
#!/bin/bash
set -e
cd /home/admin/simulator/libs/walproposer
# TODO: rewrite to Makefile
make -C ../.. neon-pg-ext-walproposer
make -C ../../pg_install/build/v15/src/backend postgres-lib -s
cp ../../pg_install/build/v15/src/backend/libpostgres.a .
cp ../../pg_install/build/v15/src/common/libpgcommon_srv.a .
cp ../../pg_install/build/v15/src/port/libpgport_srv.a .
clang -g -c libpqwalproposer.c test.c -ferror-limit=1 -I ../../pg_install/v15/include/postgresql/server -I ../../pgxn/neon
rm -rf libsim.a
ar rcs libsim.a test.o libpqwalproposer.o
rm -rf libwalproposer.a
PGXN_DIR=../../pg_install/build/neon-v15/
ar rcs libwalproposer.a $PGXN_DIR/walproposer.o $PGXN_DIR/walproposer_utils.o $PGXN_DIR/neon.o

View File

@@ -0,0 +1,542 @@
#include "postgres.h"
#include "neon.h"
#include "walproposer.h"
#include "rust_bindings.h"
#include "replication/message.h"
#include "access/xlog_internal.h"
// defined in walproposer.h
uint64 sim_redo_start_lsn;
XLogRecPtr sim_latest_available_lsn;
/* Header in walproposer.h -- Wrapper struct to abstract away the libpq connection */
struct WalProposerConn
{
int64_t tcp;
};
/* Helper function */
static bool
ensure_nonblocking_status(WalProposerConn *conn, bool is_nonblocking)
{
// walprop_log(LOG, "not implemented");
return false;
}
/* Exported function definitions */
char *
walprop_error_message(WalProposerConn *conn)
{
// walprop_log(LOG, "not implemented");
return NULL;
}
WalProposerConnStatusType
walprop_status(WalProposerConn *conn)
{
// walprop_log(LOG, "not implemented: walprop_status");
return WP_CONNECTION_OK;
}
WalProposerConn *
walprop_connect_start(char *conninfo)
{
WalProposerConn *conn;
walprop_log(LOG, "walprop_connect_start: %s", conninfo);
const char *connstr_prefix = "host=node port=";
Assert(strncmp(conninfo, connstr_prefix, strlen(connstr_prefix)) == 0);
int nodeId = atoi(conninfo + strlen(connstr_prefix));
conn = palloc(sizeof(WalProposerConn));
conn->tcp = sim_open_tcp(nodeId);
return conn;
}
WalProposerConnectPollStatusType
walprop_connect_poll(WalProposerConn *conn)
{
// walprop_log(LOG, "not implemented: walprop_connect_poll");
return WP_CONN_POLLING_OK;
}
bool
walprop_send_query(WalProposerConn *conn, char *query)
{
// walprop_log(LOG, "not implemented: walprop_send_query");
return true;
}
WalProposerExecStatusType
walprop_get_query_result(WalProposerConn *conn)
{
// walprop_log(LOG, "not implemented: walprop_get_query_result");
return WP_EXEC_SUCCESS_COPYBOTH;
}
pgsocket
walprop_socket(WalProposerConn *conn)
{
return (pgsocket) conn->tcp;
}
int
walprop_flush(WalProposerConn *conn)
{
// walprop_log(LOG, "not implemented");
return 0;
}
void
walprop_finish(WalProposerConn *conn)
{
// walprop_log(LOG, "walprop_finish not implemented");
}
/*
* Receive a message from the safekeeper.
*
* On success, the data is placed in *buf. It is valid until the next call
* to this function.
*/
PGAsyncReadResult
walprop_async_read(WalProposerConn *conn, char **buf, int *amount)
{
uintptr_t len;
char *msg;
Event event;
event = sim_epoll_peek(0);
if (event.tcp != conn->tcp || event.tag != Message || event.any_message != Bytes)
return PG_ASYNC_READ_TRY_AGAIN;
event = sim_epoll_rcv(0);
// walprop_log(LOG, "walprop_async_read, T: %d, tcp: %d, tag: %d", (int) event.tag, (int) event.tcp, (int) event.any_message);
Assert(event.tcp == conn->tcp);
Assert(event.tag == Message);
Assert(event.any_message == Bytes);
msg = (char*) sim_msg_get_bytes(&len);
*buf = msg;
*amount = len;
// walprop_log(LOG, "walprop_async_read: %d", (int) len);
return PG_ASYNC_READ_SUCCESS;
}
PGAsyncWriteResult
walprop_async_write(WalProposerConn *conn, void const *buf, size_t size)
{
// walprop_log(LOG, "walprop_async_write");
sim_msg_set_bytes(buf, size);
sim_tcp_send(conn->tcp);
return PG_ASYNC_WRITE_SUCCESS;
}
/*
* This function is very similar to walprop_async_write. For more
* information, refer to the comments there.
*/
bool
walprop_blocking_write(WalProposerConn *conn, void const *buf, size_t size)
{
// walprop_log(LOG, "walprop_blocking_write");
sim_msg_set_bytes(buf, size);
sim_tcp_send(conn->tcp);
return true;
}
void
sim_start_replication(XLogRecPtr startptr)
{
walprop_log(LOG, "sim_start_replication: %X/%X", LSN_FORMAT_ARGS(startptr));
sim_latest_available_lsn = startptr;
for (;;)
{
XLogRecPtr endptr = sim_latest_available_lsn;
Assert(startptr <= endptr);
if (endptr > startptr)
{
WalProposerBroadcast(startptr, endptr);
startptr = endptr;
}
WalProposerPoll();
}
}
#define UsableBytesInPage (XLOG_BLCKSZ - SizeOfXLogShortPHD)
static int UsableBytesInSegment =
(DEFAULT_XLOG_SEG_SIZE / XLOG_BLCKSZ * UsableBytesInPage) -
(SizeOfXLogLongPHD - SizeOfXLogShortPHD);
/*
* Converts a "usable byte position" to XLogRecPtr. A usable byte position
* is the position starting from the beginning of WAL, excluding all WAL
* page headers.
*/
static XLogRecPtr
XLogBytePosToRecPtr(uint64 bytepos)
{
uint64 fullsegs;
uint64 fullpages;
uint64 bytesleft;
uint32 seg_offset;
XLogRecPtr result;
fullsegs = bytepos / UsableBytesInSegment;
bytesleft = bytepos % UsableBytesInSegment;
if (bytesleft < XLOG_BLCKSZ - SizeOfXLogLongPHD)
{
/* fits on first page of segment */
seg_offset = bytesleft + SizeOfXLogLongPHD;
}
else
{
/* account for the first page on segment with long header */
seg_offset = XLOG_BLCKSZ;
bytesleft -= XLOG_BLCKSZ - SizeOfXLogLongPHD;
fullpages = bytesleft / UsableBytesInPage;
bytesleft = bytesleft % UsableBytesInPage;
seg_offset += fullpages * XLOG_BLCKSZ + bytesleft + SizeOfXLogShortPHD;
}
XLogSegNoOffsetToRecPtr(fullsegs, seg_offset, wal_segment_size, result);
return result;
}
/*
* Convert an XLogRecPtr to a "usable byte position".
*/
static uint64
XLogRecPtrToBytePos(XLogRecPtr ptr)
{
uint64 fullsegs;
uint32 fullpages;
uint32 offset;
uint64 result;
XLByteToSeg(ptr, fullsegs, wal_segment_size);
fullpages = (XLogSegmentOffset(ptr, wal_segment_size)) / XLOG_BLCKSZ;
offset = ptr % XLOG_BLCKSZ;
if (fullpages == 0)
{
result = fullsegs * UsableBytesInSegment;
if (offset > 0)
{
Assert(offset >= SizeOfXLogLongPHD);
result += offset - SizeOfXLogLongPHD;
}
}
else
{
result = fullsegs * UsableBytesInSegment +
(XLOG_BLCKSZ - SizeOfXLogLongPHD) + /* account for first page */
(fullpages - 1) * UsableBytesInPage; /* full pages */
if (offset > 0)
{
Assert(offset >= SizeOfXLogShortPHD);
result += offset - SizeOfXLogShortPHD;
}
}
return result;
}
#define max_rdatas 16
void InitMyInsert();
static void MyBeginInsert();
static void MyRegisterData(char *data, int len);
static XLogRecPtr MyFinishInsert(RmgrId rmid, uint8 info, uint8 flags);
static void MyCopyXLogRecordToWAL(int write_len, XLogRecData *rdata, XLogRecPtr StartPos, XLogRecPtr EndPos);
/*
* An array of XLogRecData structs, to hold registered data.
*/
static XLogRecData rdatas[max_rdatas];
static int num_rdatas; /* entries currently used */
static uint32 mainrdata_len; /* total # of bytes in chain */
static XLogRecData hdr_rdt;
static char hdr_scratch[16000];
static XLogRecPtr CurrBytePos;
static XLogRecPtr PrevBytePos;
void InitMyInsert()
{
CurrBytePos = sim_redo_start_lsn;
PrevBytePos = InvalidXLogRecPtr;
sim_latest_available_lsn = sim_redo_start_lsn;
}
static void MyBeginInsert()
{
num_rdatas = 0;
mainrdata_len = 0;
}
static void MyRegisterData(char *data, int len)
{
XLogRecData *rdata;
if (num_rdatas >= max_rdatas)
walprop_log(ERROR, "too much WAL data");
rdata = &rdatas[num_rdatas++];
rdata->data = data;
rdata->len = len;
rdata->next = NULL;
if (num_rdatas > 1) {
rdatas[num_rdatas - 2].next = rdata;
}
mainrdata_len += len;
}
static XLogRecPtr
MyFinishInsert(RmgrId rmid, uint8 info, uint8 flags)
{
XLogRecData *rdt;
uint32 total_len = 0;
int block_id;
pg_crc32c rdata_crc;
XLogRecord *rechdr;
char *scratch = hdr_scratch;
int size;
XLogRecPtr StartPos;
XLogRecPtr EndPos;
uint64 startbytepos;
uint64 endbytepos;
/*
* Note: this function can be called multiple times for the same record.
* All the modifications we do to the rdata chains below must handle that.
*/
/* The record begins with the fixed-size header */
rechdr = (XLogRecord *) scratch;
scratch += SizeOfXLogRecord;
hdr_rdt.data = hdr_scratch;
if (num_rdatas > 0)
{
hdr_rdt.next = &rdatas[0];
}
else
{
hdr_rdt.next = NULL;
}
/* followed by main data, if any */
if (mainrdata_len > 0)
{
if (mainrdata_len > 255)
{
*(scratch++) = (char) XLR_BLOCK_ID_DATA_LONG;
memcpy(scratch, &mainrdata_len, sizeof(uint32));
scratch += sizeof(uint32);
}
else
{
*(scratch++) = (char) XLR_BLOCK_ID_DATA_SHORT;
*(scratch++) = (uint8) mainrdata_len;
}
total_len += mainrdata_len;
}
hdr_rdt.len = (scratch - hdr_scratch);
total_len += hdr_rdt.len;
/*
* Calculate CRC of the data
*
* Note that the record header isn't added into the CRC initially since we
* don't know the prev-link yet. Thus, the CRC will represent the CRC of
* the whole record in the order: rdata, then backup blocks, then record
* header.
*/
INIT_CRC32C(rdata_crc);
COMP_CRC32C(rdata_crc, hdr_scratch + SizeOfXLogRecord, hdr_rdt.len - SizeOfXLogRecord);
for (size_t i = 0; i < num_rdatas; i++)
{
rdt = &rdatas[i];
COMP_CRC32C(rdata_crc, rdt->data, rdt->len);
}
/*
* Fill in the fields in the record header. Prev-link is filled in later,
* once we know where in the WAL the record will be inserted. The CRC does
* not include the record header yet.
*/
rechdr->xl_xid = 0;
rechdr->xl_tot_len = total_len;
rechdr->xl_info = info;
rechdr->xl_rmid = rmid;
rechdr->xl_prev = InvalidXLogRecPtr;
rechdr->xl_crc = rdata_crc;
size = MAXALIGN(rechdr->xl_tot_len);
/* All (non xlog-switch) records should contain data. */
Assert(size > SizeOfXLogRecord);
startbytepos = XLogRecPtrToBytePos(CurrBytePos);
endbytepos = startbytepos + size;
// Get the position.
StartPos = XLogBytePosToRecPtr(startbytepos);
EndPos = XLogBytePosToRecPtr(startbytepos + size);
rechdr->xl_prev = PrevBytePos;
Assert(XLogRecPtrToBytePos(StartPos) == startbytepos);
Assert(XLogRecPtrToBytePos(EndPos) == endbytepos);
// Update global pointers.
CurrBytePos = EndPos;
PrevBytePos = StartPos;
/*
* Now that xl_prev has been filled in, calculate CRC of the record
* header.
*/
rdata_crc = rechdr->xl_crc;
COMP_CRC32C(rdata_crc, rechdr, offsetof(XLogRecord, xl_crc));
FIN_CRC32C(rdata_crc);
rechdr->xl_crc = rdata_crc;
// Now write it to disk.
MyCopyXLogRecordToWAL(rechdr->xl_tot_len, &hdr_rdt, StartPos, EndPos);
return EndPos;
}
#define INSERT_FREESPACE(endptr) \
(((endptr) % XLOG_BLCKSZ == 0) ? 0 : (XLOG_BLCKSZ - (endptr) % XLOG_BLCKSZ))
static void
MyCopyXLogRecordToWAL(int write_len, XLogRecData *rdata, XLogRecPtr StartPos, XLogRecPtr EndPos)
{
XLogRecPtr CurrPos;
int written;
int freespace;
// Write hdr_rdt and `num_rdatas` other datas.
CurrPos = StartPos;
freespace = INSERT_FREESPACE(CurrPos);
written = 0;
Assert(freespace >= sizeof(uint32));
while (rdata != NULL)
{
char *rdata_data = rdata->data;
int rdata_len = rdata->len;
while (rdata_len >= freespace)
{
char header_buf[SizeOfXLogLongPHD];
XLogPageHeader NewPage = (XLogPageHeader) header_buf;
Assert(CurrPos % XLOG_BLCKSZ >= SizeOfXLogShortPHD || freespace == 0);
XLogWalPropWrite(rdata_data, freespace, CurrPos);
rdata_data += freespace;
rdata_len -= freespace;
written += freespace;
CurrPos += freespace;
// Init new page
MemSet(header_buf, 0, SizeOfXLogLongPHD);
/*
* Fill the new page's header
*/
NewPage->xlp_magic = XLOG_PAGE_MAGIC;
/* NewPage->xlp_info = 0; */ /* done by memset */
NewPage->xlp_tli = 1;
NewPage->xlp_pageaddr = CurrPos;
/* NewPage->xlp_rem_len = 0; */ /* done by memset */
NewPage->xlp_info |= XLP_BKP_REMOVABLE;
/*
* If first page of an XLOG segment file, make it a long header.
*/
if ((XLogSegmentOffset(NewPage->xlp_pageaddr, wal_segment_size)) == 0)
{
XLogLongPageHeader NewLongPage = (XLogLongPageHeader) NewPage;
NewLongPage->xlp_sysid = 0;
NewLongPage->xlp_seg_size = wal_segment_size;
NewLongPage->xlp_xlog_blcksz = XLOG_BLCKSZ;
NewPage->xlp_info |= XLP_LONG_HEADER;
}
NewPage->xlp_rem_len = write_len - written;
if (NewPage->xlp_rem_len > 0) {
NewPage->xlp_info |= XLP_FIRST_IS_CONTRECORD;
}
/* skip over the page header */
if (XLogSegmentOffset(CurrPos, wal_segment_size) == 0)
{
XLogWalPropWrite(header_buf, SizeOfXLogLongPHD, CurrPos);
CurrPos += SizeOfXLogLongPHD;
}
else
{
XLogWalPropWrite(header_buf, SizeOfXLogShortPHD, CurrPos);
CurrPos += SizeOfXLogShortPHD;
}
freespace = INSERT_FREESPACE(CurrPos);
}
Assert(CurrPos % XLOG_BLCKSZ >= SizeOfXLogShortPHD || rdata_len == 0);
XLogWalPropWrite(rdata_data, rdata_len, CurrPos);
CurrPos += rdata_len;
written += rdata_len;
freespace -= rdata_len;
rdata = rdata->next;
}
Assert(written == write_len);
CurrPos = MAXALIGN64(CurrPos);
Assert(CurrPos == EndPos);
}
XLogRecPtr MyInsertRecord()
{
const char *prefix = "prefix";
const char *message = "message";
size_t size = 7;
bool transactional = false;
xl_logical_message xlrec;
xlrec.dbId = 0;
xlrec.transactional = transactional;
/* trailing zero is critical; see logicalmsg_desc */
xlrec.prefix_size = strlen(prefix) + 1;
xlrec.message_size = size;
MyBeginInsert();
MyRegisterData((char *) &xlrec, SizeOfLogicalMessage);
MyRegisterData(unconstify(char *, prefix), xlrec.prefix_size);
MyRegisterData(unconstify(char *, message), size);
return MyFinishInsert(RM_LOGICALMSG_ID, XLOG_LOGICAL_MESSAGE, XLOG_INCLUDE_ORIGIN);
}

View File

@@ -0,0 +1,106 @@
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
/**
* List of all possible AnyMessage.
*/
enum AnyMessageTag {
None,
InternalConnect,
Just32,
ReplCell,
Bytes,
LSN,
};
typedef uint8_t AnyMessageTag;
/**
* List of all possible NodeEvent.
*/
enum EventTag {
Timeout,
Accept,
Closed,
Message,
Internal,
};
typedef uint8_t EventTag;
/**
* Event returned by epoll_recv.
*/
typedef struct Event {
EventTag tag;
int64_t tcp;
AnyMessageTag any_message;
} Event;
void rust_function(uint32_t a);
/**
* C API for the node os.
*/
void sim_sleep(uint64_t ms);
uint64_t sim_random(uint64_t max);
uint32_t sim_id(void);
int64_t sim_open_tcp(uint32_t dst);
int64_t sim_open_tcp_nopoll(uint32_t dst);
/**
* Send MESSAGE_BUF content to the given tcp.
*/
void sim_tcp_send(int64_t tcp);
/**
* Receive a message from the given tcp. Can be used only with tcp opened with
* `sim_open_tcp_nopoll`.
*/
struct Event sim_tcp_recv(int64_t tcp);
struct Event sim_epoll_rcv(int64_t timeout);
struct Event sim_epoll_peek(int64_t timeout);
int64_t sim_now(void);
void sim_exit(int32_t code, const uint8_t *msg);
void sim_set_result(int32_t code, const uint8_t *msg);
void sim_log_event(const int8_t *msg);
/**
* Get tag of the current message.
*/
AnyMessageTag sim_msg_tag(void);
/**
* Read AnyMessage::Just32 message.
*/
void sim_msg_get_just_u32(uint32_t *val);
/**
* Read AnyMessage::LSN message.
*/
void sim_msg_get_lsn(uint64_t *val);
/**
* Write AnyMessage::ReplCell message.
*/
void sim_msg_set_repl_cell(uint32_t value, uint32_t client_id, uint32_t seqno);
/**
* Write AnyMessage::Bytes message.
*/
void sim_msg_set_bytes(const char *bytes, uintptr_t len);
/**
* Read AnyMessage::Bytes message.
*/
const char *sim_msg_get_bytes(uintptr_t *len);

View File

@@ -0,0 +1,36 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
use safekeeper::simlib::node_os::NodeOs;
use tracing::info;
pub mod bindings {
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
#[no_mangle]
pub extern "C" fn rust_function(a: u32) {
info!("Hello from Rust!");
info!("a: {}", a);
}
pub mod sim;
pub mod sim_proto;
#[cfg(test)]
mod test;
#[cfg(test)]
pub mod simtest;
pub fn c_context() -> Option<Box<dyn Fn(NodeOs) + Send + Sync>> {
Some(Box::new(|os: NodeOs| {
sim::c_attach_node_os(os);
unsafe { bindings::MyContextInit(); }
}))
}
pub fn enable_debug() {
unsafe { bindings::debug_enabled = true; }
}

240
libs/walproposer/src/sim.rs Normal file
View File

@@ -0,0 +1,240 @@
use log::debug;
use safekeeper::simlib::{network::TCP, node_os::NodeOs, world::NodeEvent};
use std::{
cell::RefCell,
collections::HashMap,
ffi::{CStr, CString},
};
use tracing::trace;
use crate::sim_proto::{anymessage_tag, AnyMessageTag, Event, EventTag, MESSAGE_BUF};
thread_local! {
static CURRENT_NODE_OS: RefCell<Option<NodeOs>> = RefCell::new(None);
static TCP_CACHE: RefCell<HashMap<i64, TCP>> = RefCell::new(HashMap::new());
}
/// Get the current node os.
fn os() -> NodeOs {
CURRENT_NODE_OS.with(|cell| cell.borrow().clone().expect("no node os set"))
}
fn tcp_save(tcp: TCP) -> i64 {
TCP_CACHE.with(|cell| {
let mut cache = cell.borrow_mut();
let id = tcp.id();
cache.insert(id, tcp);
id
})
}
fn tcp_load(id: i64) -> TCP {
TCP_CACHE.with(|cell| {
let cache = cell.borrow();
cache.get(&id).expect("unknown TCP id").clone()
})
}
/// Should be called before calling any of the C functions.
pub(crate) fn c_attach_node_os(os: NodeOs) {
CURRENT_NODE_OS.with(|cell| {
*cell.borrow_mut() = Some(os);
});
TCP_CACHE.with(|cell| {
*cell.borrow_mut() = HashMap::new();
});
}
/// C API for the node os.
#[no_mangle]
pub extern "C" fn sim_sleep(ms: u64) {
os().sleep(ms);
}
#[no_mangle]
pub extern "C" fn sim_random(max: u64) -> u64 {
os().random(max)
}
#[no_mangle]
pub extern "C" fn sim_id() -> u32 {
os().id().into()
}
#[no_mangle]
pub extern "C" fn sim_open_tcp(dst: u32) -> i64 {
tcp_save(os().open_tcp(dst.into()))
}
#[no_mangle]
pub extern "C" fn sim_open_tcp_nopoll(dst: u32) -> i64 {
tcp_save(os().open_tcp_nopoll(dst.into()))
}
#[no_mangle]
/// Send MESSAGE_BUF content to the given tcp.
pub extern "C" fn sim_tcp_send(tcp: i64) {
tcp_load(tcp).send(MESSAGE_BUF.with(|cell| cell.borrow().clone()));
}
#[no_mangle]
/// Receive a message from the given tcp. Can be used only with tcp opened with
/// `sim_open_tcp_nopoll`.
pub extern "C" fn sim_tcp_recv(tcp: i64) -> Event {
let event = tcp_load(tcp).recv();
match event {
NodeEvent::Accept(_) => unreachable!(),
NodeEvent::Closed(_) => Event {
tag: EventTag::Closed,
tcp: 0,
any_message: AnyMessageTag::None,
},
NodeEvent::Internal(_) => unreachable!(),
NodeEvent::Message((message, _)) => {
// store message in thread local storage, C code should use
// sim_msg_* functions to access it.
MESSAGE_BUF.with(|cell| {
*cell.borrow_mut() = message.clone();
});
Event {
tag: EventTag::Message,
tcp: 0,
any_message: anymessage_tag(&message),
}
}
NodeEvent::WakeTimeout(_) => unreachable!(),
}
}
#[no_mangle]
pub extern "C" fn sim_epoll_rcv(timeout: i64) -> Event {
let event = os().epoll_recv(timeout);
let event = if let Some(event) = event {
event
} else {
return Event {
tag: EventTag::Timeout,
tcp: 0,
any_message: AnyMessageTag::None,
};
};
match event {
NodeEvent::Accept(tcp) => Event {
tag: EventTag::Accept,
tcp: tcp_save(tcp),
any_message: AnyMessageTag::None,
},
NodeEvent::Closed(tcp) => Event {
tag: EventTag::Closed,
tcp: tcp_save(tcp),
any_message: AnyMessageTag::None,
},
NodeEvent::Message((message, tcp)) => {
// store message in thread local storage, C code should use
// sim_msg_* functions to access it.
MESSAGE_BUF.with(|cell| {
*cell.borrow_mut() = message.clone();
});
Event {
tag: EventTag::Message,
tcp: tcp_save(tcp),
any_message: anymessage_tag(&message),
}
}
NodeEvent::Internal(message) => {
// store message in thread local storage, C code should use
// sim_msg_* functions to access it.
MESSAGE_BUF.with(|cell| {
*cell.borrow_mut() = message.clone();
});
Event {
tag: EventTag::Internal,
tcp: 0,
any_message: anymessage_tag(&message),
}
}
NodeEvent::WakeTimeout(_) => {
// can't happen
unreachable!()
}
}
}
#[no_mangle]
pub extern "C" fn sim_epoll_peek(timeout: i64) -> Event {
let event = os().epoll_peek(timeout);
let event = if let Some(event) = event {
event
} else {
return Event {
tag: EventTag::Timeout,
tcp: 0,
any_message: AnyMessageTag::None,
};
};
match event {
NodeEvent::Accept(tcp) => Event {
tag: EventTag::Accept,
tcp: tcp_save(tcp),
any_message: AnyMessageTag::None,
},
NodeEvent::Closed(tcp) => Event {
tag: EventTag::Closed,
tcp: tcp_save(tcp),
any_message: AnyMessageTag::None,
},
NodeEvent::Message((message, tcp)) => Event {
tag: EventTag::Message,
tcp: tcp_save(tcp),
any_message: anymessage_tag(&message),
},
NodeEvent::Internal(message) => Event {
tag: EventTag::Internal,
tcp: 0,
any_message: anymessage_tag(&message),
},
NodeEvent::WakeTimeout(_) => {
// can't happen
unreachable!()
}
}
}
#[no_mangle]
pub extern "C" fn sim_now() -> i64 {
os().now() as i64
}
#[no_mangle]
pub extern "C" fn sim_exit(code: i32, msg: *const u8) {
trace!("sim_exit({}, {:?})", code, msg);
sim_set_result(code, msg);
// I tried to make use of pthread_exit, but it doesn't work.
// https://github.com/rust-lang/unsafe-code-guidelines/issues/211
// unsafe { libc::pthread_exit(std::ptr::null_mut()) };
// https://doc.rust-lang.org/nomicon/unwinding.html
// Everyone on the internet saying this is UB, but it works for me,
// so I'm going to use it for now.
panic!("sim_exit() called from C code")
}
#[no_mangle]
pub extern "C" fn sim_set_result(code: i32, msg: *const u8) {
let msg = unsafe { CStr::from_ptr(msg as *const i8) };
let msg = msg.to_string_lossy().into_owned();
debug!("sim_set_result({}, {:?})", code, msg);
os().set_result(code, msg);
}
#[no_mangle]
pub extern "C" fn sim_log_event(msg: *const i8) {
let msg = unsafe { CStr::from_ptr(msg) };
let msg = msg.to_string_lossy().into_owned();
debug!("sim_log_event({:?})", msg);
os().log_event(msg);
}

View File

@@ -0,0 +1,114 @@
use safekeeper::simlib::proto::{AnyMessage, ReplCell};
use std::{cell::RefCell, ffi::c_char};
pub(crate) fn anymessage_tag(msg: &AnyMessage) -> AnyMessageTag {
match msg {
AnyMessage::None => AnyMessageTag::None,
AnyMessage::InternalConnect => AnyMessageTag::InternalConnect,
AnyMessage::Just32(_) => AnyMessageTag::Just32,
AnyMessage::ReplCell(_) => AnyMessageTag::ReplCell,
AnyMessage::Bytes(_) => AnyMessageTag::Bytes,
AnyMessage::LSN(_) => AnyMessageTag::LSN,
}
}
thread_local! {
pub static MESSAGE_BUF: RefCell<AnyMessage> = RefCell::new(AnyMessage::None);
}
#[no_mangle]
/// Get tag of the current message.
pub extern "C" fn sim_msg_tag() -> AnyMessageTag {
MESSAGE_BUF.with(|cell| anymessage_tag(&*cell.borrow()))
}
#[no_mangle]
/// Read AnyMessage::Just32 message.
pub extern "C" fn sim_msg_get_just_u32(val: &mut u32) {
MESSAGE_BUF.with(|cell| match &*cell.borrow() {
AnyMessage::Just32(v) => {
*val = *v;
}
_ => panic!("expected Just32 message"),
});
}
#[no_mangle]
/// Read AnyMessage::LSN message.
pub extern "C" fn sim_msg_get_lsn(val: &mut u64) {
MESSAGE_BUF.with(|cell| match &*cell.borrow() {
AnyMessage::LSN(v) => {
*val = *v;
}
_ => panic!("expected LSN message"),
});
}
#[no_mangle]
/// Write AnyMessage::ReplCell message.
pub extern "C" fn sim_msg_set_repl_cell(value: u32, client_id: u32, seqno: u32) {
MESSAGE_BUF.with(|cell| {
*cell.borrow_mut() = AnyMessage::ReplCell(ReplCell {
value,
client_id,
seqno,
});
});
}
#[no_mangle]
/// Write AnyMessage::Bytes message.
pub extern "C" fn sim_msg_set_bytes(bytes: *const c_char, len: usize) {
MESSAGE_BUF.with(|cell| {
// copy bytes to a Rust Vec
let mut v: Vec<u8> = Vec::with_capacity(len);
unsafe {
v.set_len(len);
std::ptr::copy_nonoverlapping(bytes as *const u8, v.as_mut_ptr(), len);
}
*cell.borrow_mut() = AnyMessage::Bytes(v.into());
});
}
#[no_mangle]
/// Read AnyMessage::Bytes message.
pub extern "C" fn sim_msg_get_bytes(len: *mut usize) -> *const c_char {
MESSAGE_BUF.with(|cell| match &*cell.borrow() {
AnyMessage::Bytes(v) => {
unsafe {
*len = v.len();
v.as_ptr() as *const i8
}
}
_ => panic!("expected Bytes message"),
})
}
#[repr(C)]
/// Event returned by epoll_recv.
pub struct Event {
pub tag: EventTag,
pub tcp: i64,
pub any_message: AnyMessageTag,
}
#[repr(u8)]
/// List of all possible NodeEvent.
pub enum EventTag {
Timeout,
Accept,
Closed,
Message,
Internal,
}
#[repr(u8)]
/// List of all possible AnyMessage.
pub enum AnyMessageTag {
None,
InternalConnect,
Just32,
ReplCell,
Bytes,
LSN,
}

View File

@@ -0,0 +1,88 @@
use std::collections::HashMap;
use std::sync::Arc;
use safekeeper::safekeeper::SafeKeeperState;
use safekeeper::simlib::sync::Mutex;
use utils::id::TenantTimelineId;
pub struct Disk {
pub timelines: Mutex<HashMap<TenantTimelineId, Arc<TimelineDisk>>>,
}
impl Disk {
pub fn new() -> Self {
Disk {
timelines: Mutex::new(HashMap::new()),
}
}
pub fn put_state(&self, ttid: &TenantTimelineId, state: SafeKeeperState) -> Arc<TimelineDisk> {
self.timelines
.lock()
.entry(ttid.clone())
.and_modify(|e| {
let mut mu = e.state.lock();
*mu = state.clone();
})
.or_insert_with(|| {
Arc::new(TimelineDisk {
state: Mutex::new(state),
wal: Mutex::new(BlockStorage::new()),
})
})
.clone()
}
}
pub struct TimelineDisk {
pub state: Mutex<SafeKeeperState>,
pub wal: Mutex<BlockStorage>,
}
const BLOCK_SIZE: usize = 8192;
pub struct BlockStorage {
blocks: HashMap<u64, [u8; BLOCK_SIZE]>,
}
impl BlockStorage {
pub fn new() -> Self {
BlockStorage {
blocks: HashMap::new(),
}
}
pub fn read(&self, pos: u64, buf: &mut [u8]) {
let mut buf_offset = 0;
let mut storage_pos = pos;
while buf_offset < buf.len() {
let block_id = storage_pos / BLOCK_SIZE as u64;
let block = self.blocks.get(&block_id).unwrap_or(&[0; BLOCK_SIZE]);
let block_offset = storage_pos % BLOCK_SIZE as u64;
let block_len = BLOCK_SIZE as u64 - block_offset;
let buf_len = buf.len() - buf_offset;
let copy_len = std::cmp::min(block_len as usize, buf_len);
buf[buf_offset..buf_offset + copy_len]
.copy_from_slice(&block[block_offset as usize..block_offset as usize + copy_len]);
buf_offset += copy_len;
storage_pos += copy_len as u64;
}
}
pub fn write(&mut self, pos: u64, buf: &[u8]) {
let mut buf_offset = 0;
let mut storage_pos = pos;
while buf_offset < buf.len() {
let block_id = storage_pos / BLOCK_SIZE as u64;
let block = self.blocks.entry(block_id).or_insert([0; BLOCK_SIZE]);
let block_offset = storage_pos % BLOCK_SIZE as u64;
let block_len = BLOCK_SIZE as u64 - block_offset;
let buf_len = buf.len() - buf_offset;
let copy_len = std::cmp::min(block_len as usize, buf_len);
block[block_offset as usize..block_offset as usize + copy_len]
.copy_from_slice(&buf[buf_offset..buf_offset + copy_len]);
buf_offset += copy_len;
storage_pos += copy_len as u64
}
}
}

View File

@@ -0,0 +1,61 @@
use std::{sync::Arc, fmt};
use safekeeper::simlib::{world::World, sync::Mutex};
use tracing_subscriber::fmt::{time::FormatTime, format::Writer};
use utils::logging;
use crate::bindings;
#[derive(Clone)]
pub struct SimClock {
world_ptr: Arc<Mutex<Option<Arc<World>>>>,
}
impl Default for SimClock {
fn default() -> Self {
SimClock {
world_ptr: Arc::new(Mutex::new(None)),
}
}
}
impl SimClock {
pub fn set_world(&self, world: Arc<World>) {
*self.world_ptr.lock() = Some(world);
}
}
impl FormatTime for SimClock {
fn format_time(&self, w: &mut Writer<'_>) -> fmt::Result {
let world = self.world_ptr.lock().clone();
if let Some(world) = world {
let now = world.now();
write!(w, "[{}]", now)
} else {
write!(w, "[?]")
}
}
}
pub fn init_logger() -> SimClock {
let debug_enabled = unsafe { bindings::debug_enabled };
let clock = SimClock::default();
let base_logger = tracing_subscriber::fmt()
.with_target(false)
.with_timer(clock.clone())
.with_ansi(true)
.with_max_level(match debug_enabled {
true => tracing::Level::DEBUG,
false => tracing::Level::INFO,
})
.with_writer(std::io::stdout);
base_logger.init();
// logging::replace_panic_hook_with_tracing_panic_hook().forget();
std::panic::set_hook(Box::new(|_| {}));
clock
}

View File

@@ -0,0 +1,11 @@
#[cfg(test)]
pub mod simple_client;
#[cfg(test)]
pub mod wp_sk;
pub mod disk;
pub mod safekeeper;
pub mod storage;
pub mod log;
pub mod util;

View File

@@ -0,0 +1,372 @@
//! Safekeeper communication endpoint to WAL proposer (compute node).
//! Gets messages from the network, passes them down to consensus module and
//! sends replies back.
use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
use anyhow::{anyhow, bail, Result};
use bytes::{Bytes, BytesMut};
use hyper::Uri;
use log::info;
use safekeeper::{
safekeeper::{
ProposerAcceptorMessage, SafeKeeper, SafeKeeperState, ServerInfo, UNKNOWN_SERVER_VERSION,
},
simlib::{network::TCP, node_os::NodeOs, proto::AnyMessage, world::NodeEvent},
timeline::TimelineError,
SafeKeeperConf, wal_storage::Storage,
};
use tracing::{debug, info_span};
use utils::{
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
lsn::Lsn,
};
use crate::simtest::storage::DiskStateStorage;
use super::{
disk::{Disk, TimelineDisk},
storage::DiskWALStorage,
};
struct ConnState {
tcp: TCP,
greeting: bool,
ttid: TenantTimelineId,
flush_pending: bool,
}
struct SharedState {
sk: SafeKeeper<DiskStateStorage, DiskWALStorage>,
disk: Arc<TimelineDisk>,
}
struct GlobalMap {
timelines: HashMap<TenantTimelineId, SharedState>,
conf: SafeKeeperConf,
disk: Arc<Disk>,
}
impl GlobalMap {
fn new(disk: Arc<Disk>, conf: SafeKeeperConf) -> Result<Self> {
let mut timelines = HashMap::new();
for (&ttid, disk) in disk.timelines.lock().iter() {
debug!("loading timeline {}", ttid);
let state = disk.state.lock().clone();
if state.server.wal_seg_size == 0 {
bail!(TimelineError::UninitializedWalSegSize(ttid));
}
if state.server.pg_version == UNKNOWN_SERVER_VERSION {
bail!(TimelineError::UninitialinzedPgVersion(ttid));
}
if state.commit_lsn < state.local_start_lsn {
bail!(
"commit_lsn {} is higher than local_start_lsn {}",
state.commit_lsn,
state.local_start_lsn
);
}
let control_store = DiskStateStorage::new(disk.clone());
let wal_store = DiskWALStorage::new(disk.clone(), &control_store)?;
let sk = SafeKeeper::new(control_store, wal_store, conf.my_id)?;
timelines.insert(
ttid.clone(),
SharedState {
sk,
disk: disk.clone(),
},
);
}
Ok(Self {
timelines,
conf,
disk,
})
}
fn create(&mut self, ttid: TenantTimelineId, server_info: ServerInfo) -> Result<()> {
if self.timelines.contains_key(&ttid) {
bail!("timeline {} already exists", ttid);
}
debug!("creating new timeline {}", ttid);
let commit_lsn = Lsn::INVALID;
let local_start_lsn = Lsn::INVALID;
// TODO: load state from in-memory storage
let state = SafeKeeperState::new(&ttid, server_info, vec![], commit_lsn, local_start_lsn);
if state.server.wal_seg_size == 0 {
bail!(TimelineError::UninitializedWalSegSize(ttid));
}
if state.server.pg_version == UNKNOWN_SERVER_VERSION {
bail!(TimelineError::UninitialinzedPgVersion(ttid));
}
if state.commit_lsn < state.local_start_lsn {
bail!(
"commit_lsn {} is higher than local_start_lsn {}",
state.commit_lsn,
state.local_start_lsn
);
}
let disk_timeline = self.disk.put_state(&ttid, state);
let control_store = DiskStateStorage::new(disk_timeline.clone());
let wal_store = DiskWALStorage::new(disk_timeline.clone(), &control_store)?;
let sk = SafeKeeper::new(control_store, wal_store, self.conf.my_id)?;
self.timelines.insert(
ttid.clone(),
SharedState {
sk,
disk: disk_timeline,
},
);
Ok(())
}
fn get(&mut self, ttid: &TenantTimelineId) -> &mut SharedState {
self.timelines.get_mut(ttid).expect("timeline must exist")
}
fn has_tli(&self, ttid: &TenantTimelineId) -> bool {
self.timelines.contains_key(ttid)
}
}
pub fn run_server(os: NodeOs, disk: Arc<Disk>) -> Result<()> {
let _enter = info_span!("safekeeper", id = os.id()).entered();
debug!("started server");
os.log_event("started;safekeeper".to_owned());
let conf = SafeKeeperConf {
workdir: PathBuf::from("."),
my_id: NodeId(os.id() as u64),
listen_pg_addr: String::new(),
listen_http_addr: String::new(),
no_sync: false,
broker_endpoint: "/".parse::<Uri>().unwrap(),
broker_keepalive_interval: Duration::from_secs(0),
heartbeat_timeout: Duration::from_secs(0),
remote_storage: None,
max_offloader_lag_bytes: 0,
backup_runtime_threads: None,
wal_backup_enabled: false,
auth: None,
};
let mut global = GlobalMap::new(disk, conf.clone())?;
let mut conns: HashMap<i64, ConnState> = HashMap::new();
for (&ttid, shared_state) in global.timelines.iter_mut() {
let flush_lsn = shared_state.sk.wal_store.flush_lsn();
let commit_lsn = shared_state.sk.state.commit_lsn;
os.log_event(format!("tli_loaded;{};{}", flush_lsn.0, commit_lsn.0));
}
let epoll = os.epoll();
loop {
// waiting for the next message
let mut next_event = Some(epoll.recv());
loop {
let event = match next_event {
Some(event) => event,
None => break,
};
match event {
NodeEvent::Accept(tcp) => {
conns.insert(
tcp.id(),
ConnState {
tcp,
greeting: false,
ttid: TenantTimelineId::empty(),
flush_pending: false,
},
);
}
NodeEvent::Message((msg, tcp)) => {
let conn = conns.get_mut(&tcp.id());
if let Some(conn) = conn {
let res = conn.process_any(msg, &mut global);
if res.is_err() {
debug!("conn {:?} error: {:#}", tcp, res.unwrap_err());
conns.remove(&tcp.id());
}
} else {
debug!("conn {:?} was closed, dropping msg {:?}", tcp, msg);
}
}
NodeEvent::Internal(_) => {}
NodeEvent::Closed(_) => {}
NodeEvent::WakeTimeout(_) => {}
}
// TODO: make simulator support multiple events per tick
next_event = epoll.try_recv();
}
conns.retain(|_, conn| {
let res = conn.flush(&mut global);
if res.is_err() {
debug!("conn {:?} error: {:?}", conn.tcp, res);
}
res.is_ok()
});
}
}
impl ConnState {
fn process_any(&mut self, any: AnyMessage, global: &mut GlobalMap) -> Result<()> {
if let AnyMessage::Bytes(copy_data) = any {
let repl_prefix = b"START_REPLICATION ";
if !self.greeting && copy_data.starts_with(repl_prefix) {
self.process_start_replication(copy_data.slice(repl_prefix.len()..), global)?;
bail!("finished processing START_REPLICATION")
}
let msg = ProposerAcceptorMessage::parse(copy_data)?;
debug!("got msg: {:?}", msg);
return self.process(msg, global);
} else {
bail!("unexpected message, expected AnyMessage::Bytes");
}
}
fn process_start_replication(
&mut self,
copy_data: Bytes,
global: &mut GlobalMap,
) -> Result<()> {
// format is "<tenant_id> <timeline_id> <start_lsn> <end_lsn>"
let str = String::from_utf8(copy_data.to_vec())?;
let mut parts = str.split(' ');
let tenant_id = parts.next().unwrap().parse::<TenantId>()?;
let timeline_id = parts.next().unwrap().parse::<TimelineId>()?;
let start_lsn = parts.next().unwrap().parse::<u64>()?;
let end_lsn = parts.next().unwrap().parse::<u64>()?;
let ttid = TenantTimelineId::new(tenant_id, timeline_id);
let shared_state = global.get(&ttid);
// read bytes from start_lsn to end_lsn
let mut buf = vec![0; (end_lsn - start_lsn) as usize];
shared_state.disk.wal.lock().read(start_lsn, &mut buf);
// send bytes to the client
self.tcp.send(AnyMessage::Bytes(Bytes::from(buf)));
Ok(())
}
fn init_timeline(
&mut self,
ttid: TenantTimelineId,
server_info: ServerInfo,
global: &mut GlobalMap,
) -> Result<()> {
self.ttid = ttid;
if global.has_tli(&ttid) {
return Ok(());
}
global.create(ttid, server_info)
}
fn process(&mut self, msg: ProposerAcceptorMessage, global: &mut GlobalMap) -> Result<()> {
if !self.greeting {
self.greeting = true;
match msg {
ProposerAcceptorMessage::Greeting(ref greeting) => {
debug!(
"start handshake with walproposer {:?}",
self.tcp,
);
let server_info = ServerInfo {
pg_version: greeting.pg_version,
system_id: greeting.system_id,
wal_seg_size: greeting.wal_seg_size,
};
let ttid = TenantTimelineId::new(greeting.tenant_id, greeting.timeline_id);
self.init_timeline(ttid, server_info, global)?
}
_ => {
bail!("unexpected message {msg:?} instead of greeting");
}
}
}
let tli = global.get(&self.ttid);
match msg {
ProposerAcceptorMessage::AppendRequest(append_request) => {
self.flush_pending = true;
self.process_sk_msg(
tli,
&ProposerAcceptorMessage::NoFlushAppendRequest(append_request),
)?;
}
other => {
self.process_sk_msg(tli, &other)?;
}
}
Ok(())
}
/// Process FlushWAL if needed.
// TODO: add extra flushes, to verify that extra flushes don't break anything
fn flush(&mut self, global: &mut GlobalMap) -> Result<()> {
if !self.flush_pending {
return Ok(());
}
self.flush_pending = false;
let shared_state = global.get(&self.ttid);
self.process_sk_msg(shared_state, &ProposerAcceptorMessage::FlushWAL)
}
/// Make safekeeper process a message and send a reply to the TCP
fn process_sk_msg(
&mut self,
shared_state: &mut SharedState,
msg: &ProposerAcceptorMessage,
) -> Result<()> {
let mut reply = shared_state.sk.process_msg(msg)?;
if let Some(reply) = &mut reply {
// // if this is AppendResponse, fill in proper hot standby feedback and disk consistent lsn
// if let AcceptorProposerMessage::AppendResponse(ref mut resp) = reply {
// // TODO:
// }
let mut buf = BytesMut::with_capacity(128);
reply.serialize(&mut buf)?;
self.tcp.send(AnyMessage::Bytes(buf.into()));
}
Ok(())
}
}
impl Drop for ConnState {
fn drop(&mut self) {
debug!("dropping conn: {:?}", self.tcp);
if !std::thread::panicking() {
self.tcp.close();
}
// TODO: clean up non-fsynced WAL
}
}

View File

@@ -0,0 +1,38 @@
use std::sync::Arc;
use safekeeper::{
simlib::{
network::{Delay, NetworkOptions},
world::World,
},
simtest::{start_simulation, Options},
};
use crate::{bindings::RunClientC, c_context};
#[test]
fn run_rust_c_test() {
let delay = Delay {
min: 1,
max: 5,
fail_prob: 0.5,
};
let network = NetworkOptions {
keepalive_timeout: Some(50),
connect_delay: delay.clone(),
send_delay: delay.clone(),
};
let u32_data: [u32; 5] = [1, 2, 3, 4, 5];
let world = Arc::new(World::new(1337, Arc::new(network), c_context()));
start_simulation(Options {
world,
time_limit: 1_000_000,
client_fn: Box::new(move |_, server_id| unsafe {
RunClientC(server_id);
}),
u32_data,
});
}

View File

@@ -0,0 +1,234 @@
use std::{ops::Deref, sync::Arc};
use anyhow::Result;
use bytes::{Buf, BytesMut};
use log::{debug, info};
use postgres_ffi::{waldecoder::WalStreamDecoder, XLogSegNo};
use safekeeper::{control_file, safekeeper::SafeKeeperState, wal_storage};
use utils::lsn::Lsn;
use super::disk::TimelineDisk;
pub struct DiskStateStorage {
persisted_state: SafeKeeperState,
disk: Arc<TimelineDisk>,
}
impl DiskStateStorage {
pub fn new(disk: Arc<TimelineDisk>) -> Self {
let guard = disk.state.lock();
let state = guard.clone();
drop(guard);
DiskStateStorage {
persisted_state: state,
disk,
}
}
}
impl control_file::Storage for DiskStateStorage {
fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
self.persisted_state = s.clone();
*self.disk.state.lock() = s.clone();
Ok(())
}
}
impl Deref for DiskStateStorage {
type Target = SafeKeeperState;
fn deref(&self) -> &Self::Target {
&self.persisted_state
}
}
pub struct DummyWalStore {
lsn: Lsn,
}
impl DummyWalStore {
pub fn new() -> Self {
DummyWalStore { lsn: Lsn::INVALID }
}
}
impl wal_storage::Storage for DummyWalStore {
fn flush_lsn(&self) -> Lsn {
self.lsn
}
fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
self.lsn = startpos + buf.len() as u64;
Ok(())
}
fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> {
self.lsn = end_pos;
Ok(())
}
fn flush_wal(&mut self) -> Result<()> {
Ok(())
}
fn remove_up_to(&self) -> Box<dyn Fn(XLogSegNo) -> Result<()>> {
Box::new(move |_segno_up_to: XLogSegNo| Ok(()))
}
fn get_metrics(&self) -> safekeeper::metrics::WalStorageMetrics {
safekeeper::metrics::WalStorageMetrics::default()
}
}
pub struct DiskWALStorage {
/// Written to disk, but possibly still in the cache and not fully persisted.
/// Also can be ahead of record_lsn, if happen to be in the middle of a WAL record.
write_lsn: Lsn,
/// The LSN of the last WAL record written to disk. Still can be not fully flushed.
write_record_lsn: Lsn,
/// The LSN of the last WAL record flushed to disk.
flush_record_lsn: Lsn,
/// Decoder is required for detecting boundaries of WAL records.
decoder: WalStreamDecoder,
unflushed_bytes: BytesMut,
disk: Arc<TimelineDisk>,
}
impl DiskWALStorage {
pub fn new(disk: Arc<TimelineDisk>, state: &SafeKeeperState) -> Result<Self> {
let write_lsn = if state.commit_lsn == Lsn(0) {
Lsn(0)
} else {
Self::find_end_of_wal(disk.clone(), state.commit_lsn)?
};
let flush_lsn = write_lsn;
Ok(DiskWALStorage {
write_lsn,
write_record_lsn: flush_lsn,
flush_record_lsn: flush_lsn,
decoder: WalStreamDecoder::new(flush_lsn, 15),
unflushed_bytes: BytesMut::new(),
disk,
})
}
fn find_end_of_wal(disk: Arc<TimelineDisk>, start_lsn: Lsn) -> Result<Lsn> {
let mut buf = [0; 8192];
let mut pos = start_lsn.0;
let mut decoder = WalStreamDecoder::new(start_lsn, 15);
let mut result = start_lsn;
loop {
disk.wal.lock().read(pos, &mut buf);
pos += buf.len() as u64;
decoder.feed_bytes(&buf);
loop {
match decoder.poll_decode() {
Ok(Some(record)) => result = record.0,
Err(e) => {
debug!(
"find_end_of_wal reached end at {:?}, decode error: {:?}",
result, e
);
return Ok(result);
}
Ok(None) => break, // need more data
}
}
}
}
}
impl wal_storage::Storage for DiskWALStorage {
fn flush_lsn(&self) -> Lsn {
self.flush_record_lsn
}
fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
if self.write_lsn != startpos {
panic!("write_wal called with wrong startpos");
}
self.unflushed_bytes.extend_from_slice(buf);
self.write_lsn += buf.len() as u64;
if self.decoder.available() != startpos {
info!(
"restart decoder from {} to {}",
self.decoder.available(),
startpos,
);
self.decoder = WalStreamDecoder::new(startpos, 15);
}
self.decoder.feed_bytes(buf);
loop {
match self.decoder.poll_decode()? {
None => break, // no full record yet
Some((lsn, _rec)) => {
self.write_record_lsn = lsn;
}
}
}
Ok(())
}
fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> {
if self.write_lsn != Lsn(0) && end_pos > self.write_lsn {
panic!(
"truncate_wal called on non-written WAL, write_lsn={}, end_pos={}",
self.write_lsn, end_pos
);
}
self.flush_wal()?;
// write zeroes to disk from end_pos until self.write_lsn
let buf = [0; 8192];
let mut pos = end_pos.0;
while pos < self.write_lsn.0 {
self.disk.wal.lock().write(pos, &buf);
pos += buf.len() as u64;
}
self.write_lsn = end_pos;
self.write_record_lsn = end_pos;
self.flush_record_lsn = end_pos;
self.unflushed_bytes.clear();
self.decoder = WalStreamDecoder::new(end_pos, 15);
Ok(())
}
fn flush_wal(&mut self) -> Result<()> {
if self.flush_record_lsn == self.write_record_lsn {
// no need to do extra flush
return Ok(());
}
let num_bytes = self.write_record_lsn.0 - self.flush_record_lsn.0;
self.disk.wal.lock().write(
self.flush_record_lsn.0,
&self.unflushed_bytes[..num_bytes as usize],
);
self.unflushed_bytes.advance(num_bytes as usize);
self.flush_record_lsn = self.write_record_lsn;
Ok(())
}
fn remove_up_to(&self) -> Box<dyn Fn(XLogSegNo) -> Result<()>> {
Box::new(move |_segno_up_to: XLogSegNo| Ok(()))
}
fn get_metrics(&self) -> safekeeper::metrics::WalStorageMetrics {
safekeeper::metrics::WalStorageMetrics::default()
}
}

View File

@@ -0,0 +1,610 @@
use std::{ffi::CString, path::Path, str::FromStr, sync::Arc, collections::HashMap};
use rand::{Rng, SeedableRng};
use safekeeper::simlib::{
network::{Delay, NetworkOptions},
proto::AnyMessage,
time::EmptyEvent,
world::World,
world::{Node, NodeEvent, SEvent, NodeId},
};
use tracing::{debug, error, info, warn};
use utils::{id::TenantTimelineId, lsn::Lsn};
use crate::{
bindings::{
neon_tenant_walproposer, neon_timeline_walproposer, sim_redo_start_lsn, syncSafekeepers,
wal_acceptor_connection_timeout, wal_acceptor_reconnect_timeout, wal_acceptors_list,
MyInsertRecord, WalProposerCleanup, WalProposerRust,
},
c_context,
simtest::{
log::{init_logger, SimClock},
safekeeper::run_server,
},
};
use super::disk::Disk;
pub struct SkNode {
pub node: Arc<Node>,
pub id: u32,
pub disk: Arc<Disk>,
}
impl SkNode {
pub fn new(node: Arc<Node>) -> Self {
let disk = Arc::new(Disk::new());
let res = Self {
id: node.id,
node,
disk,
};
res.launch();
res
}
pub fn launch(&self) {
let id = self.id;
let disk = self.disk.clone();
// start the server thread
self.node.launch(move |os| {
let res = run_server(os, disk);
debug!("server {} finished: {:?}", id, res);
});
}
pub fn restart(&self) {
self.node.crash_stop();
self.launch();
}
}
pub struct TestConfig {
pub network: NetworkOptions,
pub timeout: u64,
pub clock: Option<SimClock>,
}
impl TestConfig {
pub fn new(clock: Option<SimClock>) -> Self {
Self {
network: NetworkOptions {
keepalive_timeout: Some(2000),
connect_delay: Delay {
min: 1,
max: 5,
fail_prob: 0.0,
},
send_delay: Delay {
min: 1,
max: 5,
fail_prob: 0.0,
},
},
timeout: 1_000 * 10,
clock,
}
}
pub fn start(&self, seed: u64) -> Test {
let world = Arc::new(World::new(
seed,
Arc::new(self.network.clone()),
c_context(),
));
world.register_world();
if let Some(clock) = &self.clock {
clock.set_world(world.clone());
}
let servers = [
SkNode::new(world.new_node()),
SkNode::new(world.new_node()),
SkNode::new(world.new_node()),
];
let server_ids = [servers[0].id, servers[1].id, servers[2].id];
let safekeepers_guc = server_ids.map(|id| format!("node:{}", id)).join(",");
let ttid = TenantTimelineId::generate();
// wait init for all servers
world.await_all();
// clean up pgdata directory
self.init_pgdata();
Test {
world,
servers,
safekeepers_guc,
ttid,
timeout: self.timeout,
}
}
pub fn init_pgdata(&self) {
let pgdata = Path::new("/home/admin/simulator/libs/walproposer/pgdata");
if pgdata.exists() {
std::fs::remove_dir_all(pgdata).unwrap();
}
std::fs::create_dir(pgdata).unwrap();
// create empty pg_wal and pg_notify subdirs
std::fs::create_dir(pgdata.join("pg_wal")).unwrap();
std::fs::create_dir(pgdata.join("pg_notify")).unwrap();
// write postgresql.conf
let mut conf = std::fs::File::create(pgdata.join("postgresql.conf")).unwrap();
let content = "
wal_log_hints=off
hot_standby=on
fsync=off
wal_level=replica
restart_after_crash=off
shared_preload_libraries=neon
neon.pageserver_connstring=''
neon.tenant_id=cc6e67313d57283bad411600fbf5c142
neon.timeline_id=de6fa815c1e45aa61491c3d34c4eb33e
synchronous_standby_names=walproposer
neon.safekeepers='node:1,node:2,node:3'
max_connections=100
";
std::io::Write::write_all(&mut conf, content.as_bytes()).unwrap();
}
}
pub struct Test {
pub world: Arc<World>,
pub servers: [SkNode; 3],
pub safekeepers_guc: String,
pub ttid: TenantTimelineId,
pub timeout: u64,
}
impl Test {
fn launch_sync(&self) -> Arc<Node> {
let client_node = self.world.new_node();
debug!("sync-safekeepers started at node {}", client_node.id);
// start the client thread
let guc = self.safekeepers_guc.clone();
let ttid = self.ttid.clone();
client_node.launch(move |_| {
let list = CString::new(guc).unwrap();
unsafe {
WalProposerCleanup();
syncSafekeepers = true;
wal_acceptors_list = list.into_raw();
wal_acceptor_reconnect_timeout = 1000;
wal_acceptor_connection_timeout = 5000;
neon_tenant_walproposer =
CString::new(ttid.tenant_id.to_string()).unwrap().into_raw();
neon_timeline_walproposer = CString::new(ttid.timeline_id.to_string())
.unwrap()
.into_raw();
WalProposerRust();
}
});
self.world.await_all();
client_node
}
pub fn sync_safekeepers(&self) -> anyhow::Result<Lsn> {
let client_node = self.launch_sync();
// poll until exit or timeout
let time_limit = self.timeout;
while self.world.step() && self.world.now() < time_limit && !client_node.is_finished() {}
if !client_node.is_finished() {
anyhow::bail!("timeout or idle stuck");
}
let res = client_node.result.lock().clone();
if res.0 != 0 {
anyhow::bail!("non-zero exitcode: {:?}", res);
}
let lsn = Lsn::from_str(&res.1)?;
Ok(lsn)
}
pub fn launch_walproposer(&self, lsn: Lsn) -> WalProposer {
let client_node = self.world.new_node();
let lsn = if lsn.0 == 0 {
// usual LSN after basebackup
Lsn(21623024)
} else {
lsn
};
// start the client thread
let guc = self.safekeepers_guc.clone();
let ttid = self.ttid.clone();
client_node.launch(move |_| {
let list = CString::new(guc).unwrap();
unsafe {
WalProposerCleanup();
sim_redo_start_lsn = lsn.0;
syncSafekeepers = false;
wal_acceptors_list = list.into_raw();
wal_acceptor_reconnect_timeout = 1000;
wal_acceptor_connection_timeout = 5000;
neon_tenant_walproposer =
CString::new(ttid.tenant_id.to_string()).unwrap().into_raw();
neon_timeline_walproposer = CString::new(ttid.timeline_id.to_string())
.unwrap()
.into_raw();
WalProposerRust();
}
});
self.world.await_all();
WalProposer {
node: client_node,
}
}
pub fn poll_for_duration(&self, duration: u64) {
let time_limit = std::cmp::min(self.world.now() + duration, self.timeout);
while self.world.step() && self.world.now() < time_limit {}
}
pub fn run_schedule(&self, schedule: &Schedule) -> anyhow::Result<()> {
{
let empty_event = Box::new(EmptyEvent);
let now = self.world.now();
for (time, _) in schedule {
if *time < now {
continue;
}
self.world.schedule(*time - now, empty_event.clone())
}
}
let mut wait_node = self.launch_sync();
// fake walproposer
let mut wp = WalProposer {
node: wait_node.clone(),
};
let mut sync_in_progress = true;
let mut skipped_tx = 0;
let mut started_tx = 0;
let mut schedule_ptr = 0;
loop {
if sync_in_progress && wait_node.is_finished() {
let res = wait_node.result.lock().clone();
if res.0 != 0 {
warn!("sync non-zero exitcode: {:?}", res);
debug!("restarting walproposer");
wait_node = self.launch_sync();
continue;
}
let lsn = Lsn::from_str(&res.1)?;
debug!("sync-safekeepers finished at LSN {}", lsn);
wp = self.launch_walproposer(lsn);
wait_node = wp.node.clone();
debug!("walproposer started at node {}", wait_node.id);
sync_in_progress = false;
}
let now = self.world.now();
while schedule_ptr < schedule.len() && schedule[schedule_ptr].0 <= now {
if now != schedule[schedule_ptr].0 {
warn!("skipped event {:?} at {}", schedule[schedule_ptr], now);
}
let action = &schedule[schedule_ptr].1;
match action {
TestAction::WriteTx(size) => {
if !sync_in_progress && !wait_node.is_finished() {
started_tx += *size;
wp.write_tx(*size);
debug!("written {} transactions", size);
} else {
skipped_tx += size;
debug!("skipped {} transactions", size);
}
}
TestAction::RestartSafekeeper(id) => {
debug!("restarting safekeeper {}", id);
self.servers[*id as usize].restart();
}
TestAction::RestartWalProposer => {
debug!("restarting walproposer");
wait_node.crash_stop();
sync_in_progress = true;
wait_node = self.launch_sync();
}
}
schedule_ptr += 1;
}
if schedule_ptr == schedule.len() {
break;
}
let next_event_time = schedule[schedule_ptr].0;
// poll until the next event
if wait_node.is_finished() {
while self.world.step() && self.world.now() < next_event_time {}
} else {
while self.world.step()
&& self.world.now() < next_event_time
&& !wait_node.is_finished()
{}
}
}
debug!("finished schedule");
debug!("skipped_tx: {}", skipped_tx);
debug!("started_tx: {}", started_tx);
Ok(())
}
}
pub struct WalProposer {
pub node: Arc<Node>,
}
impl WalProposer {
pub fn write_tx(&mut self, cnt: usize) {
self.node
.network_chan()
.send(NodeEvent::Internal(AnyMessage::Just32(cnt as u32)));
}
pub fn stop(&self) {
self.node.crash_stop();
}
}
#[derive(Debug, Clone)]
pub enum TestAction {
WriteTx(usize),
RestartSafekeeper(usize),
RestartWalProposer,
}
pub type Schedule = Vec<(u64, TestAction)>;
pub fn generate_schedule(seed: u64) -> Schedule {
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let mut schedule = Vec::new();
let mut time = 0;
let cnt = rng.gen_range(1..100);
for _ in 0..cnt {
time += rng.gen_range(0..500);
let action = match rng.gen_range(0..3) {
0 => TestAction::WriteTx(rng.gen_range(1..10)),
1 => TestAction::RestartSafekeeper(rng.gen_range(0..3)),
2 => TestAction::RestartWalProposer,
_ => unreachable!(),
};
schedule.push((time, action));
}
schedule
}
pub fn generate_network_opts(seed: u64) -> NetworkOptions {
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let timeout = rng.gen_range(100..2000);
let max_delay = rng.gen_range(1..2*timeout);
let min_delay = rng.gen_range(1..=max_delay);
let max_fail_prob = rng.gen_range(0.0..0.9);
let connect_fail_prob = rng.gen_range(0.0..max_fail_prob);
let send_fail_prob = rng.gen_range(0.0..connect_fail_prob);
NetworkOptions {
keepalive_timeout: Some(timeout),
connect_delay: Delay {
min: min_delay,
max: max_delay,
fail_prob: connect_fail_prob,
},
send_delay: Delay {
min: min_delay,
max: max_delay,
fail_prob: send_fail_prob,
},
}
}
#[derive(Debug,Clone,PartialEq,Eq)]
enum NodeKind {
Unknown,
Safekeeper,
WalProposer,
}
impl Default for NodeKind {
fn default() -> Self {
Self::Unknown
}
}
#[derive(Clone, Debug, Default)]
struct NodeInfo {
kind: NodeKind,
// walproposer
is_sync: bool,
term: u64,
epoch_lsn: u64,
// safekeeper
commit_lsn: u64,
flush_lsn: u64,
}
impl NodeInfo {
fn init_kind(&mut self, kind: NodeKind) {
if self.kind == NodeKind::Unknown {
self.kind = kind;
} else {
assert!(self.kind == kind);
}
}
fn started(&mut self, data: &str) {
let mut parts = data.split(';');
assert!(parts.next().unwrap() == "started");
match parts.next().unwrap() {
"safekeeper" => {
self.init_kind(NodeKind::Safekeeper);
}
"walproposer" => {
self.init_kind(NodeKind::WalProposer);
let is_sync: u8 = parts.next().unwrap().parse().unwrap();
self.is_sync = is_sync != 0;
}
_ => unreachable!(),
}
}
}
#[derive(Debug,Default)]
struct GlobalState {
nodes: Vec<NodeInfo>,
commit_lsn: u64,
write_lsn: u64,
max_write_lsn: u64,
written_wal: u64,
written_records: u64,
}
impl GlobalState {
fn new() -> Self {
Default::default()
}
fn get(&mut self, id: u32) -> &mut NodeInfo {
let id = id as usize;
if id >= self.nodes.len() {
self.nodes.resize(id + 1, NodeInfo::default());
}
&mut self.nodes[id]
}
}
pub fn validate_events(events: Vec<SEvent>) {
const INITDB_LSN: u64 = 21623024;
let hook = std::panic::take_hook();
scopeguard::defer_on_success! {
std::panic::set_hook(hook);
};
let mut state = GlobalState::new();
state.max_write_lsn = INITDB_LSN;
for event in events {
debug!("{:?}", event);
let node = state.get(event.node);
if event.data.starts_with("started;") {
node.started(&event.data);
continue;
}
assert!(node.kind != NodeKind::Unknown);
// drop reference to unlock state
let mut node = node.clone();
let mut parts = event.data.split(';');
match node.kind {
NodeKind::Safekeeper => {
match parts.next().unwrap() {
"tli_loaded" => {
let flush_lsn: u64 = parts.next().unwrap().parse().unwrap();
let commit_lsn: u64 = parts.next().unwrap().parse().unwrap();
node.flush_lsn = flush_lsn;
node.commit_lsn = commit_lsn;
}
_ => unreachable!(),
}
}
NodeKind::WalProposer => {
match parts.next().unwrap() {
"prop_elected" => {
let prop_lsn: u64 = parts.next().unwrap().parse().unwrap();
let prop_term: u64 = parts.next().unwrap().parse().unwrap();
let prev_lsn: u64 = parts.next().unwrap().parse().unwrap();
let prev_term: u64 = parts.next().unwrap().parse().unwrap();
assert!(prop_lsn >= prev_lsn);
assert!(prop_term >= prev_term);
assert!(prop_lsn >= state.commit_lsn);
if prop_lsn > state.write_lsn {
assert!(prop_lsn <= state.max_write_lsn);
debug!("moving write_lsn up from {} to {}", state.write_lsn, prop_lsn);
state.write_lsn = prop_lsn;
}
if prop_lsn < state.write_lsn {
debug!("moving write_lsn down from {} to {}", state.write_lsn, prop_lsn);
state.write_lsn = prop_lsn;
}
node.epoch_lsn = prop_lsn;
node.term = prop_term;
}
"write_wal" => {
assert!(!node.is_sync);
let start_lsn: u64 = parts.next().unwrap().parse().unwrap();
let end_lsn: u64 = parts.next().unwrap().parse().unwrap();
let cnt: u64 = parts.next().unwrap().parse().unwrap();
let size = end_lsn - start_lsn;
state.written_wal += size;
state.written_records += cnt;
// TODO: If we allow writing WAL before winning the election
assert!(start_lsn >= state.commit_lsn);
assert!(end_lsn >= start_lsn);
assert!(start_lsn == state.write_lsn);
state.write_lsn = end_lsn;
if end_lsn > state.max_write_lsn {
state.max_write_lsn = end_lsn;
}
}
"commit_lsn" => {
let lsn: u64 = parts.next().unwrap().parse().unwrap();
assert!(lsn >= state.commit_lsn);
state.commit_lsn = lsn;
}
_ => unreachable!(),
}
}
_ => unreachable!(),
}
// update the node in the state struct
*state.get(event.node) = node;
}
}

View File

@@ -0,0 +1,265 @@
use std::{ffi::CString, path::Path, str::FromStr, sync::Arc};
use rand::Rng;
use safekeeper::simlib::{
network::{Delay, NetworkOptions},
proto::AnyMessage,
world::World,
world::{Node, NodeEvent},
};
use tracing::{info, warn};
use utils::{id::TenantTimelineId, lsn::Lsn};
use crate::{
bindings::{
neon_tenant_walproposer, neon_timeline_walproposer, sim_redo_start_lsn, syncSafekeepers,
wal_acceptor_connection_timeout, wal_acceptor_reconnect_timeout, wal_acceptors_list,
MyInsertRecord, WalProposerCleanup, WalProposerRust,
},
c_context,
simtest::{
log::{init_logger, SimClock},
safekeeper::run_server,
util::{generate_schedule, TestConfig, generate_network_opts, validate_events},
}, enable_debug,
};
use super::{
disk::Disk,
util::{Schedule, TestAction},
};
#[test]
fn sync_empty_safekeepers() {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
let test = config.start(1337);
let lsn = test.sync_safekeepers().unwrap();
assert_eq!(lsn, Lsn(0));
info!("Sucessfully synced empty safekeepers at 0/0");
let lsn = test.sync_safekeepers().unwrap();
assert_eq!(lsn, Lsn(0));
info!("Sucessfully synced (again) empty safekeepers at 0/0");
}
#[test]
fn run_walproposer_generate_wal() {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
// config.network.timeout = Some(250);
let test = config.start(1337);
let lsn = test.sync_safekeepers().unwrap();
assert_eq!(lsn, Lsn(0));
info!("Sucessfully synced empty safekeepers at 0/0");
let mut wp = test.launch_walproposer(lsn);
test.poll_for_duration(30);
for i in 0..100 {
wp.write_tx(1);
test.poll_for_duration(5);
}
}
#[test]
fn crash_safekeeper() {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
// config.network.timeout = Some(250);
let test = config.start(1337);
let lsn = test.sync_safekeepers().unwrap();
assert_eq!(lsn, Lsn(0));
info!("Sucessfully synced empty safekeepers at 0/0");
let mut wp = test.launch_walproposer(lsn);
test.poll_for_duration(30);
wp.write_tx(3);
test.servers[0].restart();
test.poll_for_duration(100);
test.poll_for_duration(1000);
}
#[test]
fn test_simple_restart() {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
// config.network.timeout = Some(250);
let test = config.start(1337);
let lsn = test.sync_safekeepers().unwrap();
assert_eq!(lsn, Lsn(0));
info!("Sucessfully synced empty safekeepers at 0/0");
let mut wp = test.launch_walproposer(lsn);
test.poll_for_duration(30);
wp.write_tx(3);
test.poll_for_duration(100);
wp.stop();
drop(wp);
let lsn = test.sync_safekeepers().unwrap();
info!("Sucessfully synced safekeepers at {}", lsn);
}
#[test]
fn test_simple_schedule() -> anyhow::Result<()> {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
config.network.keepalive_timeout = Some(100);
let test = config.start(1337);
let schedule: Schedule = vec![
(0, TestAction::RestartWalProposer),
(50, TestAction::WriteTx(5)),
(100, TestAction::RestartSafekeeper(0)),
(100, TestAction::WriteTx(5)),
(110, TestAction::RestartSafekeeper(1)),
(110, TestAction::WriteTx(5)),
(120, TestAction::RestartSafekeeper(2)),
(120, TestAction::WriteTx(5)),
(201, TestAction::RestartWalProposer),
(251, TestAction::RestartSafekeeper(0)),
(251, TestAction::RestartSafekeeper(1)),
(251, TestAction::RestartSafekeeper(2)),
(251, TestAction::WriteTx(5)),
(255, TestAction::WriteTx(5)),
(1000, TestAction::WriteTx(5)),
];
test.run_schedule(&schedule)?;
info!("Test finished, stopping all threads");
test.world.deallocate();
Ok(())
}
#[test]
fn test_many_tx() -> anyhow::Result<()> {
enable_debug();
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
let test = config.start(1337);
let mut schedule: Schedule = vec![];
for i in 0..100 {
schedule.push((i * 10, TestAction::WriteTx(10)));
}
test.run_schedule(&schedule)?;
info!("Test finished, stopping all threads");
test.world.stop_all();
let events = test.world.take_events();
info!("Events: {:?}", events);
let last_commit_lsn = events
.iter()
.filter_map(|event| {
if event.data.starts_with("commit_lsn;") {
let lsn: u64 = event.data.split(';').nth(1).unwrap().parse().unwrap();
return Some(lsn);
}
None
})
.last()
.unwrap();
let initdb_lsn = 21623024;
let diff = last_commit_lsn - initdb_lsn;
info!("Last commit lsn: {}, diff: {}", last_commit_lsn, diff);
assert!(diff > 1000 * 8);
Ok(())
}
#[test]
fn test_random_schedules() -> anyhow::Result<()> {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
config.network.keepalive_timeout = Some(100);
for i in 0..30000 {
let seed: u64 = rand::thread_rng().gen();
config.network = generate_network_opts(seed);
let test = config.start(seed);
warn!("Running test with seed {}", seed);
let schedule = generate_schedule(seed);
test.run_schedule(&schedule).unwrap();
validate_events(test.world.take_events());
test.world.deallocate();
}
Ok(())
}
#[test]
fn test_one_schedule() -> anyhow::Result<()> {
enable_debug();
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
config.network.keepalive_timeout = Some(100);
// let seed = 6762900106769428342;
// let test = config.start(seed);
// warn!("Running test with seed {}", seed);
// let schedule = generate_schedule(seed);
// info!("schedule: {:?}", schedule);
// test.run_schedule(&schedule)?;
// test.world.deallocate();
let seed = 3649773280641776194;
config.network = generate_network_opts(seed);
info!("network: {:?}", config.network);
let test = config.start(seed);
warn!("Running test with seed {}", seed);
let schedule = generate_schedule(seed);
info!("schedule: {:?}", schedule);
test.run_schedule(&schedule).unwrap();
validate_events(test.world.take_events());
test.world.deallocate();
Ok(())
}
#[test]
fn test_res_dealloc() -> anyhow::Result<()> {
// enable_debug();
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
// print pid
let pid = unsafe { libc::getpid() };
info!("pid: {}", pid);
let seed = 123456;
config.network = generate_network_opts(seed);
let test = config.start(seed);
warn!("Running test with seed {}", seed);
let schedule = generate_schedule(seed);
info!("schedule: {:?}", schedule);
test.run_schedule(&schedule).unwrap();
test.world.stop_all();
let world = test.world.clone();
drop(test);
info!("world strong count: {}", Arc::strong_count(&world));
world.deallocate();
info!("world strong count: {}", Arc::strong_count(&world));
Ok(())
}

View File

@@ -0,0 +1,31 @@
use tracing::info;
use crate::bindings::{TestFunc, MyContextInit};
#[test]
fn test_rust_c_calls() {
let res = std::thread::spawn(|| {
let res = unsafe {
MyContextInit();
TestFunc(1, 2)
};
res
}).join().unwrap();
info!("res: {}", res);
}
#[test]
fn test_sim_bindings() {
std::thread::spawn(|| {
unsafe {
MyContextInit();
TestFunc(1, 2)
}
}).join().unwrap();
std::thread::spawn(|| {
unsafe {
MyContextInit();
TestFunc(1, 2)
}
}).join().unwrap();
}

100
libs/walproposer/test.c Normal file
View File

@@ -0,0 +1,100 @@
#include "bindgen_deps.h"
#include "rust_bindings.h"
#include <stdio.h>
#include <pthread.h>
#include <stdlib.h>
#include "postgres.h"
#include "utils/memutils.h"
#include "utils/guc.h"
#include "miscadmin.h"
#include "common/pg_prng.h"
// From src/backend/main/main.c
const char *progname = "fakepostgres";
int TestFunc(int a, int b) {
printf("TestFunc: %d + %d = %d\n", a, b, a + b);
rust_function(0);
elog(LOG, "postgres elog test");
printf("After rust_function\n");
return a + b;
}
// This is a quick experiment with rewriting existing Rust code in C.
void RunClientC(uint32_t serverId) {
uint32_t clientId = sim_id();
elog(LOG, "started client");
int data_len = 5;
int delivered = 0;
int tcp = sim_open_tcp(serverId);
while (delivered < data_len) {
sim_msg_set_repl_cell(delivered+1, clientId, delivered);
sim_tcp_send(tcp);
Event event = sim_epoll_rcv(-1);
switch (event.tag)
{
case Closed:
elog(LOG, "connection closed");
tcp = sim_open_tcp(serverId);
break;
case Message:
Assert(event.any_message == Just32);
uint32_t val;
sim_msg_get_just_u32(&val);
if (val == delivered + 1) {
delivered += 1;
}
break;
default:
Assert(false);
}
}
}
bool debug_enabled = false;
bool initializedMemoryContext = false;
// pthread_mutex_init(&lock, NULL)?
pthread_mutex_t lock;
void MyContextInit() {
// initializes global variables, TODO how to make them thread-local?
pthread_mutex_lock(&lock);
if (!initializedMemoryContext) {
initializedMemoryContext = true;
MemoryContextInit();
pg_prng_seed(&pg_global_prng_state, 0);
setenv("PGDATA", "/home/admin/simulator/libs/walproposer/pgdata", 1);
/*
* Set default values for command-line options.
*/
InitializeGUCOptions();
/* Acquire configuration parameters */
if (!SelectConfigFiles(NULL, progname))
exit(1);
if (debug_enabled) {
log_min_messages = LOG;
} else {
log_min_messages = FATAL;
}
Log_line_prefix = "[%p] ";
InitializeMaxBackends();
ChangeToDataDir();
CreateSharedMemoryAndSemaphores();
SetInstallXLogFileSegmentActive();
// CreateAuxProcessResourceOwner();
// StartupXLOG();
}
pthread_mutex_unlock(&lock);
}

View File

@@ -37,7 +37,6 @@ num-traits.workspace = true
once_cell.workspace = true
pin-project-lite.workspace = true
postgres.workspace = true
postgres_backend.workspace = true
postgres-protocol.workspace = true
postgres-types.workspace = true
rand.workspace = true

View File

@@ -23,10 +23,11 @@ use pageserver::{
tenant::mgr,
virtual_file,
};
use postgres_backend::AuthType;
use utils::{
auth::JwtAuth,
logging, project_git_version,
logging,
postgres_backend::AuthType,
project_git_version,
sentry_init::init_sentry,
signals::{self, Signal},
tcp_listener,
@@ -270,43 +271,43 @@ fn start_pageserver(
WALRECEIVER_RUNTIME.block_on(pageserver::broker_client::init_broker_client(conf))?;
// Initialize authentication for incoming connections
let http_auth;
let pg_auth;
if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT {
// unwrap is ok because check is performed when creating config, so path is set and file exists
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
info!(
"Loading public key for verifying JWT tokens from {:#?}",
key_path
);
let auth: Arc<JwtAuth> = Arc::new(JwtAuth::from_key_path(key_path)?);
let auth = match &conf.auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => {
// unwrap is ok because check is performed when creating config, so path is set and file exists
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
Some(JwtAuth::from_key_path(key_path)?.into())
}
};
info!("Using auth: {:#?}", conf.auth_type);
http_auth = match &conf.http_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth.clone()),
};
pg_auth = match &conf.pg_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth),
};
} else {
http_auth = None;
pg_auth = None;
}
info!("Using auth for http API: {:#?}", conf.http_auth_type);
info!("Using auth for pg connections: {:#?}", conf.pg_auth_type);
match var("NEON_AUTH_TOKEN") {
Ok(v) => {
// TODO: remove ZENITH_AUTH_TOKEN once it's not used anywhere in development/staging/prod configuration.
match (var("ZENITH_AUTH_TOKEN"), var("NEON_AUTH_TOKEN")) {
(old, Ok(v)) => {
info!("Loaded JWT token for authentication with Safekeeper");
if let Ok(v_old) = old {
warn!(
"JWT token for Safekeeper is specified twice, ZENITH_AUTH_TOKEN is deprecated"
);
if v_old != v {
warn!("JWT token for Safekeeper has two different values, choosing NEON_AUTH_TOKEN");
}
}
pageserver::config::SAFEKEEPER_AUTH_TOKEN
.set(Arc::new(v))
.map_err(|_| anyhow!("Could not initialize SAFEKEEPER_AUTH_TOKEN"))?;
}
Err(VarError::NotPresent) => {
(Ok(v), _) => {
info!("Loaded JWT token for authentication with Safekeeper");
warn!("Please update pageserver configuration: the JWT token should be NEON_AUTH_TOKEN, not ZENITH_AUTH_TOKEN");
pageserver::config::SAFEKEEPER_AUTH_TOKEN
.set(Arc::new(v))
.map_err(|_| anyhow!("Could not initialize SAFEKEEPER_AUTH_TOKEN"))?;
}
(_, Err(VarError::NotPresent)) => {
info!("No JWT token for authentication with Safekeeper detected");
}
Err(e) => {
(_, Err(e)) => {
return Err(e).with_context(|| {
"Failed to either load to detect non-present NEON_AUTH_TOKEN environment variable"
})
@@ -324,7 +325,7 @@ fn start_pageserver(
{
let _rt_guard = MGMT_REQUEST_RUNTIME.enter();
let router = http::make_router(conf, launch_ts, http_auth, remote_storage)?
let router = http::make_router(conf, launch_ts, auth.clone(), remote_storage)?
.build()
.map_err(|err| anyhow!(err))?;
let service = utils::http::RouterService::new(router).unwrap();
@@ -398,9 +399,9 @@ fn start_pageserver(
async move {
page_service::libpq_listener_main(
conf,
pg_auth,
auth,
pageserver_listener,
conf.pg_auth_type,
conf.auth_type,
libpq_ctx,
)
.await

View File

@@ -21,10 +21,10 @@ use std::time::Duration;
use toml_edit;
use toml_edit::{Document, Item};
use postgres_backend::AuthType;
use utils::{
id::{NodeId, TenantId, TimelineId},
logging::LogFormat,
postgres_backend::AuthType,
};
use crate::tenant::config::TenantConf;
@@ -118,9 +118,6 @@ pub struct PageServerConf {
/// Example (default): 127.0.0.1:9898
pub listen_http_addr: String,
/// Current availability zone. Used for traffic metrics.
pub availability_zone: Option<String>,
// Timeout when waiting for WAL receiver to catch up to an LSN given in a GetPage@LSN call.
pub wait_lsn_timeout: Duration,
// How long to wait for WAL redo to complete.
@@ -141,15 +138,9 @@ pub struct PageServerConf {
pub pg_distrib_dir: PathBuf,
// Authentication
/// authentication method for the HTTP mgmt API
pub http_auth_type: AuthType,
/// authentication method for libpq connections from compute
pub pg_auth_type: AuthType,
/// Path to a file containing public key for verifying JWT tokens.
/// Used for both mgmt and compute auth, if enabled.
pub auth_validation_public_key_path: Option<PathBuf>,
pub auth_type: AuthType,
pub auth_validation_public_key_path: Option<PathBuf>,
pub remote_storage_config: Option<RemoteStorageConfig>,
pub default_tenant_conf: TenantConf,
@@ -205,8 +196,6 @@ struct PageServerConfigBuilder {
listen_http_addr: BuilderValue<String>,
availability_zone: BuilderValue<Option<String>>,
wait_lsn_timeout: BuilderValue<Duration>,
wal_redo_timeout: BuilderValue<Duration>,
@@ -219,8 +208,7 @@ struct PageServerConfigBuilder {
pg_distrib_dir: BuilderValue<PathBuf>,
http_auth_type: BuilderValue<AuthType>,
pg_auth_type: BuilderValue<AuthType>,
auth_type: BuilderValue<AuthType>,
//
auth_validation_public_key_path: BuilderValue<Option<PathBuf>>,
@@ -252,7 +240,6 @@ impl Default for PageServerConfigBuilder {
Self {
listen_pg_addr: Set(DEFAULT_PG_LISTEN_ADDR.to_string()),
listen_http_addr: Set(DEFAULT_HTTP_LISTEN_ADDR.to_string()),
availability_zone: Set(None),
wait_lsn_timeout: Set(humantime::parse_duration(DEFAULT_WAIT_LSN_TIMEOUT)
.expect("cannot parse default wait lsn timeout")),
wal_redo_timeout: Set(humantime::parse_duration(DEFAULT_WAL_REDO_TIMEOUT)
@@ -264,8 +251,7 @@ impl Default for PageServerConfigBuilder {
pg_distrib_dir: Set(env::current_dir()
.expect("cannot access current directory")
.join("pg_install")),
http_auth_type: Set(AuthType::Trust),
pg_auth_type: Set(AuthType::Trust),
auth_type: Set(AuthType::Trust),
auth_validation_public_key_path: Set(None),
remote_storage_config: Set(None),
id: NotSet,
@@ -309,10 +295,6 @@ impl PageServerConfigBuilder {
self.listen_http_addr = BuilderValue::Set(listen_http_addr)
}
pub fn availability_zone(&mut self, availability_zone: Option<String>) {
self.availability_zone = BuilderValue::Set(availability_zone)
}
pub fn wait_lsn_timeout(&mut self, wait_lsn_timeout: Duration) {
self.wait_lsn_timeout = BuilderValue::Set(wait_lsn_timeout)
}
@@ -341,12 +323,8 @@ impl PageServerConfigBuilder {
self.pg_distrib_dir = BuilderValue::Set(pg_distrib_dir)
}
pub fn http_auth_type(&mut self, auth_type: AuthType) {
self.http_auth_type = BuilderValue::Set(auth_type)
}
pub fn pg_auth_type(&mut self, auth_type: AuthType) {
self.pg_auth_type = BuilderValue::Set(auth_type)
pub fn auth_type(&mut self, auth_type: AuthType) {
self.auth_type = BuilderValue::Set(auth_type)
}
pub fn auth_validation_public_key_path(
@@ -424,9 +402,6 @@ impl PageServerConfigBuilder {
listen_http_addr: self
.listen_http_addr
.ok_or(anyhow!("missing listen_http_addr"))?,
availability_zone: self
.availability_zone
.ok_or(anyhow!("missing availability_zone"))?,
wait_lsn_timeout: self
.wait_lsn_timeout
.ok_or(anyhow!("missing wait_lsn_timeout"))?,
@@ -444,10 +419,7 @@ impl PageServerConfigBuilder {
pg_distrib_dir: self
.pg_distrib_dir
.ok_or(anyhow!("missing pg_distrib_dir"))?,
http_auth_type: self
.http_auth_type
.ok_or(anyhow!("missing http_auth_type"))?,
pg_auth_type: self.pg_auth_type.ok_or(anyhow!("missing pg_auth_type"))?,
auth_type: self.auth_type.ok_or(anyhow!("missing auth_type"))?,
auth_validation_public_key_path: self
.auth_validation_public_key_path
.ok_or(anyhow!("missing auth_validation_public_key_path"))?,
@@ -627,7 +599,6 @@ impl PageServerConf {
match key {
"listen_pg_addr" => builder.listen_pg_addr(parse_toml_string(key, item)?),
"listen_http_addr" => builder.listen_http_addr(parse_toml_string(key, item)?),
"availability_zone" => builder.availability_zone(Some(parse_toml_string(key, item)?)),
"wait_lsn_timeout" => builder.wait_lsn_timeout(parse_toml_duration(key, item)?),
"wal_redo_timeout" => builder.wal_redo_timeout(parse_toml_duration(key, item)?),
"initial_superuser_name" => builder.superuser(parse_toml_string(key, item)?),
@@ -641,8 +612,7 @@ impl PageServerConf {
"auth_validation_public_key_path" => builder.auth_validation_public_key_path(Some(
PathBuf::from(parse_toml_string(key, item)?),
)),
"http_auth_type" => builder.http_auth_type(parse_toml_from_str(key, item)?),
"pg_auth_type" => builder.pg_auth_type(parse_toml_from_str(key, item)?),
"auth_type" => builder.auth_type(parse_toml_from_str(key, item)?),
"remote_storage" => {
builder.remote_storage_config(RemoteStorageConfig::from_toml(item)?)
}
@@ -677,7 +647,7 @@ impl PageServerConf {
let mut conf = builder.build().context("invalid config")?;
if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT {
if conf.auth_type == AuthType::NeonJWT {
let auth_validation_public_key_path = conf
.auth_validation_public_key_path
.get_or_insert_with(|| workdir.join("auth_public_key.pem"));
@@ -728,12 +698,6 @@ impl PageServerConf {
Some(parse_toml_u64("compaction_threshold", compaction_threshold)?.try_into()?);
}
if let Some(image_creation_threshold) = item.get("image_creation_threshold") {
t_conf.image_creation_threshold = Some(
parse_toml_u64("image_creation_threshold", image_creation_threshold)?.try_into()?,
);
}
if let Some(gc_horizon) = item.get("gc_horizon") {
t_conf.gc_horizon = Some(parse_toml_u64("gc_horizon", gc_horizon)?);
}
@@ -793,12 +757,10 @@ impl PageServerConf {
max_file_descriptors: defaults::DEFAULT_MAX_FILE_DESCRIPTORS,
listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(),
listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(),
availability_zone: None,
superuser: "cloud_admin".to_string(),
workdir: repo_dir,
pg_distrib_dir,
http_auth_type: AuthType::Trust,
pg_auth_type: AuthType::Trust,
auth_type: AuthType::Trust,
auth_validation_public_key_path: None,
remote_storage_config: None,
default_tenant_conf: TenantConf::default(),
@@ -976,7 +938,6 @@ log_format = 'json'
id: NodeId(10),
listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(),
listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(),
availability_zone: None,
wait_lsn_timeout: humantime::parse_duration(defaults::DEFAULT_WAIT_LSN_TIMEOUT)?,
wal_redo_timeout: humantime::parse_duration(defaults::DEFAULT_WAL_REDO_TIMEOUT)?,
superuser: defaults::DEFAULT_SUPERUSER.to_string(),
@@ -984,8 +945,7 @@ log_format = 'json'
max_file_descriptors: defaults::DEFAULT_MAX_FILE_DESCRIPTORS,
workdir,
pg_distrib_dir,
http_auth_type: AuthType::Trust,
pg_auth_type: AuthType::Trust,
auth_type: AuthType::Trust,
auth_validation_public_key_path: None,
remote_storage_config: None,
default_tenant_conf: TenantConf::default(),
@@ -1035,7 +995,6 @@ log_format = 'json'
id: NodeId(10),
listen_pg_addr: "127.0.0.1:64000".to_string(),
listen_http_addr: "127.0.0.1:9898".to_string(),
availability_zone: None,
wait_lsn_timeout: Duration::from_secs(111),
wal_redo_timeout: Duration::from_secs(111),
superuser: "zzzz".to_string(),
@@ -1043,8 +1002,7 @@ log_format = 'json'
max_file_descriptors: 333,
workdir,
pg_distrib_dir,
http_auth_type: AuthType::Trust,
pg_auth_type: AuthType::Trust,
auth_type: AuthType::Trust,
auth_validation_public_key_path: None,
remote_storage_config: None,
default_tenant_conf: TenantConf::default(),

View File

@@ -245,53 +245,6 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/timeline/{timeline_id}/do_gc:
parameters:
- name: tenant_id
in: path
required: true
schema:
type: string
format: hex
- name: timeline_id
in: path
required: true
schema:
type: string
format: hex
put:
description: Garbage collect given timeline
responses:
"200":
description: OK
content:
application/json:
schema:
type: string
"400":
description: Error when no tenant id found in path, no timeline id or invalid timestamp
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/attach:
parameters:
- name: tenant_id

View File

@@ -10,7 +10,6 @@ use remote_storage::GenericRemoteStorage;
use tenant_size_model::{SizeResult, StorageModel};
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::http::endpoint::RequestSpan;
use utils::http::request::{get_request_param, must_get_query_param, parse_query_param};
use super::models::{
@@ -21,7 +20,7 @@ use crate::context::{DownloadBehavior, RequestContext};
use crate::pgdatadir_mapping::LsnForTimestamp;
use crate::task_mgr::TaskKind;
use crate::tenant::config::TenantConfOpt;
use crate::tenant::mgr::{TenantMapInsertError, TenantStateError};
use crate::tenant::mgr::TenantMapInsertError;
use crate::tenant::size::ModelInputs;
use crate::tenant::storage_layer::LayerAccessStatsReset;
use crate::tenant::{PageReconstructError, Timeline};
@@ -82,52 +81,38 @@ fn get_config(request: &Request<Body>) -> &'static PageServerConf {
get_state(request).conf
}
/// Check that the requester is authorized to operate on given tenant
fn check_permission(request: &Request<Body>, tenant_id: Option<TenantId>) -> Result<(), ApiError> {
check_permission_with(request, |claims| {
crate::auth::check_permission(claims, tenant_id)
})
}
impl From<PageReconstructError> for ApiError {
fn from(pre: PageReconstructError) -> ApiError {
match pre {
PageReconstructError::Other(pre) => ApiError::InternalServerError(pre),
PageReconstructError::NeedsDownload(_, _) => {
// This shouldn't happen, because we use a RequestContext that requests to
// download any missing layer files on-demand.
ApiError::InternalServerError(anyhow::anyhow!("need to download remote layer file"))
}
PageReconstructError::Cancelled => {
ApiError::InternalServerError(anyhow::anyhow!("request was cancelled"))
}
PageReconstructError::WalRedo(pre) => {
ApiError::InternalServerError(anyhow::Error::new(pre))
}
fn apierror_from_prerror(err: PageReconstructError) -> ApiError {
match err {
PageReconstructError::Other(err) => ApiError::InternalServerError(err),
PageReconstructError::NeedsDownload(_, _) => {
// This shouldn't happen, because we use a RequestContext that requests to
// download any missing layer files on-demand.
ApiError::InternalServerError(anyhow::anyhow!("need to download remote layer file"))
}
PageReconstructError::Cancelled => {
ApiError::InternalServerError(anyhow::anyhow!("request was cancelled"))
}
PageReconstructError::WalRedo(err) => {
ApiError::InternalServerError(anyhow::Error::new(err))
}
}
}
impl From<TenantMapInsertError> for ApiError {
fn from(tmie: TenantMapInsertError) -> ApiError {
match tmie {
TenantMapInsertError::StillInitializing | TenantMapInsertError::ShuttingDown => {
ApiError::InternalServerError(anyhow::Error::new(tmie))
}
TenantMapInsertError::TenantAlreadyExists(id, state) => {
ApiError::Conflict(format!("tenant {id} already exists, state: {state:?}"))
}
TenantMapInsertError::Closure(e) => ApiError::InternalServerError(e),
fn apierror_from_tenant_map_insert_error(e: TenantMapInsertError) -> ApiError {
match e {
TenantMapInsertError::StillInitializing | TenantMapInsertError::ShuttingDown => {
ApiError::InternalServerError(anyhow::Error::new(e))
}
}
}
impl From<TenantStateError> for ApiError {
fn from(tse: TenantStateError) -> ApiError {
match tse {
TenantStateError::NotFound(tid) => ApiError::NotFound(anyhow!("tenant {}", tid)),
_ => ApiError::InternalServerError(anyhow::Error::new(tse)),
TenantMapInsertError::TenantAlreadyExists(id, state) => {
ApiError::Conflict(format!("tenant {id} already exists, state: {state:?}"))
}
TenantMapInsertError::Closure(e) => ApiError::InternalServerError(e),
}
}
@@ -231,7 +216,9 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Error);
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::NotFound)?;
match tenant.create_timeline(
new_timeline_id,
request_data.ancestor_timeline_id.map(TimelineId::from),
@@ -261,7 +248,9 @@ async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>,
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let response_data = async {
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::NotFound)?;
let timelines = tenant.list_timelines();
let mut response_data = Vec::with_capacity(timelines.len());
@@ -277,7 +266,7 @@ async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>,
response_data.push(timeline_info);
}
Ok::<Vec<TimelineInfo>, ApiError>(response_data)
Ok(response_data)
}
.instrument(info_span!("timeline_list", tenant = %tenant_id))
.await?;
@@ -296,7 +285,9 @@ async fn timeline_detail_handler(request: Request<Body>) -> Result<Response<Body
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline_info = async {
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::NotFound)?;
let timeline = tenant
.get_timeline(timeline_id, false)
@@ -332,7 +323,10 @@ async fn get_lsn_by_timestamp_handler(request: Request<Body>) -> Result<Response
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline = active_timeline_of_active_tenant(tenant_id, timeline_id).await?;
let result = timeline.find_lsn_for_timestamp(timestamp_pg, &ctx).await?;
let result = timeline
.find_lsn_for_timestamp(timestamp_pg, &ctx)
.await
.map_err(apierror_from_prerror)?;
let result = match result {
LsnForTimestamp::Present(lsn) => format!("{lsn}"),
@@ -357,7 +351,8 @@ async fn tenant_attach_handler(request: Request<Body>) -> Result<Response<Body>,
if let Some(remote_storage) = &state.remote_storage {
mgr::attach_tenant(state.conf, tenant_id, remote_storage.clone(), &ctx)
.instrument(info_span!("tenant_attach", tenant = %tenant_id))
.await?;
.await
.map_err(apierror_from_tenant_map_insert_error)?;
} else {
return Err(ApiError::BadRequest(anyhow!(
"attach_tenant is not possible because pageserver was configured without remote storage"
@@ -376,7 +371,11 @@ async fn timeline_delete_handler(request: Request<Body>) -> Result<Response<Body
mgr::delete_timeline(tenant_id, timeline_id, &ctx)
.instrument(info_span!("timeline_delete", tenant = %tenant_id, timeline = %timeline_id))
.await?;
.await
// FIXME: Errors from `delete_timeline` can occur for a number of reasons, incuding both
// user and internal errors. Replace this with better handling once the error type permits
// it.
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -389,7 +388,10 @@ async fn tenant_detach_handler(request: Request<Body>) -> Result<Response<Body>,
let conf = state.conf;
mgr::detach_tenant(conf, tenant_id)
.instrument(info_span!("tenant_detach", tenant = %tenant_id))
.await?;
.await
// FIXME: Errors from `detach_tenant` can be caused by both both user and internal errors.
// Replace this with better handling once the error type permits it.
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -403,7 +405,8 @@ async fn tenant_load_handler(request: Request<Body>) -> Result<Response<Body>, A
let state = get_state(&request);
mgr::load_tenant(state.conf, tenant_id, state.remote_storage.clone(), &ctx)
.instrument(info_span!("load", tenant = %tenant_id))
.await?;
.await
.map_err(apierror_from_tenant_map_insert_error)?;
json_response(StatusCode::ACCEPTED, ())
}
@@ -416,7 +419,10 @@ async fn tenant_ignore_handler(request: Request<Body>) -> Result<Response<Body>,
let conf = state.conf;
mgr::ignore_tenant(conf, tenant_id)
.instrument(info_span!("ignore_tenant", tenant = %tenant_id))
.await?;
.await
// FIXME: Errors from `ignore_tenant` can be caused by both both user and internal errors.
// Replace this with better handling once the error type permits it.
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -490,7 +496,9 @@ async fn tenant_size_handler(request: Request<Body>) -> Result<Response<Body>, A
let headers = request.headers();
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::InternalServerError)?;
// this can be long operation
let inputs = tenant
@@ -753,7 +761,8 @@ async fn tenant_create_handler(mut request: Request<Body>) -> Result<Response<Bo
&ctx,
)
.instrument(info_span!("tenant_create", tenant = ?target_tenant_id))
.await?;
.await
.map_err(apierror_from_tenant_map_insert_error)?;
// We created the tenant. Existing API semantics are that the tenant
// is Active when this function returns.
@@ -777,7 +786,9 @@ async fn get_tenant_config_handler(request: Request<Body>) -> Result<Response<Bo
let tenant_id: TenantId = parse_request_param(&request, "tenant_id")?;
check_permission(&request, Some(tenant_id))?;
let tenant = mgr::get_tenant(tenant_id, false).await?;
let tenant = mgr::get_tenant(tenant_id, false)
.await
.map_err(ApiError::NotFound)?;
let response = HashMap::from([
(
@@ -872,7 +883,10 @@ async fn update_tenant_config_handler(
let state = get_state(&request);
mgr::set_new_tenant_config(state.conf, tenant_conf, tenant_id)
.instrument(info_span!("tenant_config", tenant = ?tenant_id))
.await?;
.await
// FIXME: `update_tenant_config` can fail because of both user and internal errors.
// Replace this `map_err` with better error handling once the type permits it
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -1009,7 +1023,9 @@ async fn active_timeline_of_active_tenant(
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<Arc<Timeline>, ApiError> {
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::NotFound)?;
tenant
.get_timeline(timeline_id, true)
.map_err(ApiError::NotFound)
@@ -1075,8 +1091,7 @@ pub fn make_router(
let handler = $handler;
#[cfg(not(feature = "testing"))]
let handler = cfg_disabled;
move |r| RequestSpan(handler).handle(r)
handler
}};
}
@@ -1084,55 +1099,35 @@ pub fn make_router(
.data(Arc::new(
State::new(conf, auth, remote_storage).context("Failed to initialize router state")?,
))
.get("/v1/status", |r| RequestSpan(status_handler).handle(r))
.get("/v1/status", status_handler)
.put(
"/v1/failpoints",
testing_api!("manage failpoints", failpoints_handler),
)
.get("/v1/tenant", |r| RequestSpan(tenant_list_handler).handle(r))
.post("/v1/tenant", |r| {
RequestSpan(tenant_create_handler).handle(r)
})
.get("/v1/tenant/:tenant_id", |r| {
RequestSpan(tenant_status).handle(r)
})
.get("/v1/tenant/:tenant_id/synthetic_size", |r| {
RequestSpan(tenant_size_handler).handle(r)
})
.put("/v1/tenant/config", |r| {
RequestSpan(update_tenant_config_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/config", |r| {
RequestSpan(get_tenant_config_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline", |r| {
RequestSpan(timeline_list_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/timeline", |r| {
RequestSpan(timeline_create_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/attach", |r| {
RequestSpan(tenant_attach_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/detach", |r| {
RequestSpan(tenant_detach_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/load", |r| {
RequestSpan(tenant_load_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/ignore", |r| {
RequestSpan(tenant_ignore_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
RequestSpan(timeline_detail_handler).handle(r)
})
.get("/v1/tenant", tenant_list_handler)
.post("/v1/tenant", tenant_create_handler)
.get("/v1/tenant/:tenant_id", tenant_status)
.get("/v1/tenant/:tenant_id/synthetic_size", tenant_size_handler)
.put("/v1/tenant/config", update_tenant_config_handler)
.get("/v1/tenant/:tenant_id/config", get_tenant_config_handler)
.get("/v1/tenant/:tenant_id/timeline", timeline_list_handler)
.post("/v1/tenant/:tenant_id/timeline", timeline_create_handler)
.post("/v1/tenant/:tenant_id/attach", tenant_attach_handler)
.post("/v1/tenant/:tenant_id/detach", tenant_detach_handler)
.post("/v1/tenant/:tenant_id/load", tenant_load_handler)
.post("/v1/tenant/:tenant_id/ignore", tenant_ignore_handler)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_detail_handler,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/get_lsn_by_timestamp",
|r| RequestSpan(get_lsn_by_timestamp_handler).handle(r),
get_lsn_by_timestamp_handler,
)
.put(
"/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc",
timeline_gc_handler,
)
.put("/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc", |r| {
RequestSpan(timeline_gc_handler).handle(r)
})
.put(
"/v1/tenant/:tenant_id/timeline/:timeline_id/compact",
testing_api!("run timeline compaction", timeline_compact_handler),
@@ -1143,26 +1138,28 @@ pub fn make_router(
)
.post(
"/v1/tenant/:tenant_id/timeline/:timeline_id/download_remote_layers",
|r| RequestSpan(timeline_download_remote_layers_handler_post).handle(r),
timeline_download_remote_layers_handler_post,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/download_remote_layers",
|r| RequestSpan(timeline_download_remote_layers_handler_get).handle(r),
timeline_download_remote_layers_handler_get,
)
.delete(
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_delete_handler,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer",
layer_map_info_handler,
)
.delete("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
RequestSpan(timeline_delete_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline/:timeline_id/layer", |r| {
RequestSpan(layer_map_info_handler).handle(r)
})
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer/:layer_file_name",
|r| RequestSpan(layer_download_handler).handle(r),
layer_download_handler,
)
.delete(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer/:layer_file_name",
|r| RequestSpan(evict_timeline_layer_handler).handle(r),
evict_timeline_layer_handler,
)
.get("/v1/panic", |r| RequestSpan(always_panic_handler).handle(r))
.get("/v1/panic", always_panic_handler)
.any(handler_404))
}

View File

@@ -123,22 +123,6 @@ static REMOTE_PHYSICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
.expect("failed to define a metric")
});
pub static REMOTE_ONDEMAND_DOWNLOADED_LAYERS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_remote_ondemand_downloaded_layers_total",
"Total on-demand downloaded layers"
)
.unwrap()
});
pub static REMOTE_ONDEMAND_DOWNLOADED_BYTES: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_remote_ondemand_downloaded_bytes_total",
"Total bytes of layers on-demand downloaded",
)
.unwrap()
});
static CURRENT_LOGICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_current_logical_size",

View File

@@ -12,7 +12,7 @@
use anyhow::Context;
use bytes::Buf;
use bytes::Bytes;
use futures::Stream;
use futures::{Stream, StreamExt};
use pageserver_api::models::TenantState;
use pageserver_api::models::{
PagestreamBeMessage, PagestreamDbSizeRequest, PagestreamDbSizeResponse,
@@ -20,9 +20,7 @@ use pageserver_api::models::{
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
PagestreamNblocksRequest, PagestreamNblocksResponse,
};
use postgres_backend::PostgresBackendTCP;
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
use pq_proto::framed::ConnectionError;
use pq_proto::ConnectionError;
use pq_proto::FeStartupPacket;
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
use std::io;
@@ -31,13 +29,14 @@ use std::str;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio_util::io::StreamReader;
use tracing::*;
use utils::id::ConnectionId;
use utils::{
auth::{Claims, JwtAuth, Scope},
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
postgres_backend_async::{self, is_expected_io_error, PostgresBackend, QueryError},
simple_rcu::RcuReadGuard,
};
@@ -56,7 +55,7 @@ use crate::trace::Tracer;
use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
use postgres_ffi::BLCKSZ;
fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<Bytes>> + '_ {
fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Bytes>> + '_ {
async_stream::try_stream! {
loop {
let msg = tokio::select! {
@@ -65,11 +64,11 @@ fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<
_ = task_mgr::shutdown_watcher() => {
// We were requested to shut down.
let msg = format!("pageserver is shutting down");
let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None));
let _ = pgb.write_message(&BeMessage::ErrorResponse(&msg, None));
Err(QueryError::Other(anyhow::anyhow!(msg)))
}
msg = pgb.read_message() => { msg.map_err(QueryError::from)}
msg = pgb.read_message() => { msg }
};
match msg {
@@ -80,16 +79,14 @@ fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<
FeMessage::Sync => continue,
FeMessage::Terminate => {
let msg = "client terminated connection with Terminate message during COPY";
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
break;
}
m => {
let msg = format!("unexpected message {m:?}");
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?;
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))?;
Err(io::Error::new(io::ErrorKind::Other, msg))?;
break;
}
@@ -99,66 +96,22 @@ fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<
}
Ok(None) => {
let msg = "client closed connection during COPY";
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
pgb.flush().await?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
}
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
Err(io_error)?;
}
Err(other) => {
Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
Err(io::Error::new(io::ErrorKind::Other, other))?;
}
};
}
}
}
/// Read the end of a tar archive.
///
/// A tar archive normally ends with two consecutive blocks of zeros, 512 bytes each.
/// `tokio_tar` already read the first such block. Read the second all-zeros block,
/// and check that there is no more data after the EOF marker.
///
/// XXX: Currently, any trailing data after the EOF marker prints a warning.
/// Perhaps it should be a hard error?
async fn read_tar_eof(mut reader: (impl tokio::io::AsyncRead + Unpin)) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt;
let mut buf = [0u8; 512];
// Read the all-zeros block, and verify it
let mut total_bytes = 0;
while total_bytes < 512 {
let nbytes = reader.read(&mut buf[total_bytes..]).await?;
total_bytes += nbytes;
if nbytes == 0 {
break;
}
}
if total_bytes < 512 {
anyhow::bail!("incomplete or invalid tar EOF marker");
}
if !buf.iter().all(|&x| x == 0) {
anyhow::bail!("invalid tar EOF marker");
}
// Drain any data after the EOF marker
let mut trailing_bytes = 0;
loop {
let nbytes = reader.read(&mut buf).await?;
trailing_bytes += nbytes;
if nbytes == 0 {
break;
}
}
if trailing_bytes > 0 {
warn!("ignored {trailing_bytes} unexpected bytes after the tar archive");
}
Ok(())
}
///////////////////////////////////////////////////////////////////////////////
///
@@ -259,7 +212,7 @@ async fn page_service_conn_main(
// we've been requested to shut down
Ok(())
}
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
if is_expected_io_error(&io_error) {
info!("Postgres client disconnected ({io_error})");
Ok(())
@@ -333,7 +286,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_pagerequests(
&self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
tenant_id: TenantId,
timeline_id: TimelineId,
ctx: RequestContext,
@@ -358,7 +311,7 @@ impl PageServerHandler {
let timeline = tenant.get_timeline(timeline_id, true)?;
// switch client to COPYBOTH
pgb.write_message_noflush(&BeMessage::CopyBothResponse)?;
pgb.write_message(&BeMessage::CopyBothResponse)?;
pgb.flush().await?;
let metrics = PageRequestMetrics::new(&tenant_id, &timeline_id);
@@ -427,7 +380,7 @@ impl PageServerHandler {
})
});
pgb.write_message_noflush(&BeMessage::CopyData(&response.serialize()))?;
pgb.write_message(&BeMessage::CopyData(&response.serialize()))?;
pgb.flush().await?;
}
Ok(())
@@ -437,7 +390,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_import_basebackup(
&self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
tenant_id: TenantId,
timeline_id: TimelineId,
base_lsn: Lsn,
@@ -463,17 +416,22 @@ impl PageServerHandler {
// Import basebackup provided via CopyData
info!("importing basebackup");
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
pgb.write_message(&BeMessage::CopyInResponse)?;
pgb.flush().await?;
let copyin_reader = StreamReader::new(copyin_stream(pgb));
tokio::pin!(copyin_reader);
let mut copyin_stream = Box::pin(copyin_stream(pgb));
timeline
.import_basebackup_from_tar(&mut copyin_reader, base_lsn, &ctx)
.import_basebackup_from_tar(&mut copyin_stream, base_lsn, &ctx)
.await?;
// Read the end of the tar archive.
read_tar_eof(copyin_reader).await?;
// Drain the rest of the Copy data
let mut bytes_after_tar = 0;
while let Some(bytes) = copyin_stream.next().await {
bytes_after_tar += bytes?.len();
}
if bytes_after_tar > 0 {
warn!("ignored {bytes_after_tar} unexpected bytes after the tar archive");
}
// TODO check checksum
// Meanwhile you can verify client-side by taking fullbackup
@@ -488,7 +446,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_import_wal(
&self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
tenant_id: TenantId,
timeline_id: TimelineId,
start_lsn: Lsn,
@@ -510,15 +468,21 @@ impl PageServerHandler {
// Import wal provided via CopyData
info!("importing wal");
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
pgb.write_message(&BeMessage::CopyInResponse)?;
pgb.flush().await?;
let copyin_reader = StreamReader::new(copyin_stream(pgb));
tokio::pin!(copyin_reader);
import_wal_from_tar(&timeline, &mut copyin_reader, start_lsn, end_lsn, &ctx).await?;
let mut copyin_stream = Box::pin(copyin_stream(pgb));
let mut reader = tokio_util::io::StreamReader::new(&mut copyin_stream);
import_wal_from_tar(&timeline, &mut reader, start_lsn, end_lsn, &ctx).await?;
info!("wal import complete");
// Read the end of the tar archive.
read_tar_eof(copyin_reader).await?;
// Drain the rest of the Copy data
let mut bytes_after_tar = 0;
while let Some(bytes) = copyin_stream.next().await {
bytes_after_tar += bytes?.len();
}
if bytes_after_tar > 0 {
warn!("ignored {bytes_after_tar} unexpected bytes after the tar archive");
}
// TODO Does it make sense to overshoot?
if timeline.get_last_record_lsn() < end_lsn {
@@ -693,7 +657,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_basebackup_request(
&mut self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
tenant_id: TenantId,
timeline_id: TimelineId,
lsn: Option<Lsn>,
@@ -714,7 +678,7 @@ impl PageServerHandler {
}
// switch client to COPYOUT
pgb.write_message_noflush(&BeMessage::CopyOutResponse)?;
pgb.write_message(&BeMessage::CopyOutResponse)?;
pgb.flush().await?;
// Send a tarball of the latest layer on the timeline
@@ -731,7 +695,7 @@ impl PageServerHandler {
.await?;
}
pgb.write_message_noflush(&BeMessage::CopyDone)?;
pgb.write_message(&BeMessage::CopyDone)?;
pgb.flush().await?;
info!("basebackup complete");
@@ -757,10 +721,10 @@ impl PageServerHandler {
}
#[async_trait::async_trait]
impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
impl postgres_backend_async::Handler for PageServerHandler {
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackendTCP,
_pgb: &mut PostgresBackend,
jwt_response: &[u8],
) -> Result<(), QueryError> {
// this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
@@ -788,7 +752,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
fn startup(
&mut self,
_pgb: &mut PostgresBackendTCP,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
@@ -796,7 +760,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
async fn process_query(
&mut self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError> {
let ctx = self.connection_ctx.attached_child();
@@ -848,7 +812,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
// Check that the timeline exists
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, None, false, ctx)
.await?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
// return pair of prev_lsn and last_lsn
else if query_string.starts_with("get_last_record_rlsn ") {
@@ -871,15 +835,15 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
let end_of_timeline = timeline.get_last_record_rlsn();
pgb.write_message_noflush(&BeMessage::RowDescription(&[
pgb.write_message(&BeMessage::RowDescription(&[
RowDescriptor::text_col(b"prev_lsn"),
RowDescriptor::text_col(b"last_lsn"),
]))?
.write_message_noflush(&BeMessage::DataRow(&[
.write_message(&BeMessage::DataRow(&[
Some(end_of_timeline.prev.to_string().as_bytes()),
Some(end_of_timeline.last.to_string().as_bytes()),
]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
// same as basebackup, but result includes relational data as well
else if query_string.starts_with("fullbackup ") {
@@ -920,7 +884,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
// Check that the timeline exists
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, prev_lsn, true, ctx)
.await?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("import basebackup ") {
// Import the `base` section (everything but the wal) of a basebackup.
// Assumes the tenant already exists on this pageserver.
@@ -965,10 +929,10 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
)
.await
{
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => {
error!("error importing base backup between {base_lsn} and {end_lsn}: {e:?}");
pgb.write_message_noflush(&BeMessage::ErrorResponse(
pgb.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?
@@ -1001,10 +965,10 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
.handle_import_wal(pgb, tenant_id, timeline_id, start_lsn, end_lsn, ctx)
.await
{
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => {
error!("error importing WAL between {start_lsn} and {end_lsn}: {e:?}");
pgb.write_message_noflush(&BeMessage::ErrorResponse(
pgb.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?
@@ -1013,7 +977,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
} else if query_string.to_ascii_lowercase().starts_with("set ") {
// important because psycopg2 executes "SET datestyle TO 'ISO'"
// on connect
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("show ") {
// show <tenant_id>
let (_, params_raw) = query_string.split_at("show ".len());
@@ -1029,7 +993,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
self.check_permission(Some(tenant_id))?;
let tenant = get_active_tenant_with_timeout(tenant_id, &ctx).await?;
pgb.write_message_noflush(&BeMessage::RowDescription(&[
pgb.write_message(&BeMessage::RowDescription(&[
RowDescriptor::int8_col(b"checkpoint_distance"),
RowDescriptor::int8_col(b"checkpoint_timeout"),
RowDescriptor::int8_col(b"compaction_target_size"),
@@ -1040,7 +1004,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
RowDescriptor::int8_col(b"image_creation_threshold"),
RowDescriptor::int8_col(b"pitr_interval"),
]))?
.write_message_noflush(&BeMessage::DataRow(&[
.write_message(&BeMessage::DataRow(&[
Some(tenant.get_checkpoint_distance().to_string().as_bytes()),
Some(
tenant
@@ -1063,7 +1027,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
Some(tenant.get_image_creation_threshold().to_string().as_bytes()),
Some(tenant.get_pitr_interval().as_secs().to_string().as_bytes()),
]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else {
return Err(QueryError::Other(anyhow::anyhow!(
"unknown command {query_string}"
@@ -1091,7 +1055,7 @@ impl From<GetActiveTenantError> for QueryError {
fn from(e: GetActiveTenantError) -> Self {
match e {
GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected(
ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
ConnectionError::Socket(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
),
GetActiveTenantError::Other(e) => QueryError::Other(e),
}
@@ -1107,10 +1071,7 @@ async fn get_active_tenant_with_timeout(
tenant_id: TenantId,
_ctx: &RequestContext, /* require get a context to support cancellation in the future */
) -> Result<Arc<Tenant>, GetActiveTenantError> {
let tenant = match mgr::get_tenant(tenant_id, false).await {
Ok(tenant) => tenant,
Err(e) => return Err(GetActiveTenantError::Other(e.into())),
};
let tenant = mgr::get_tenant(tenant_id, false).await?;
let wait_time = Duration::from_secs(30);
match tokio::time::timeout(wait_time, tenant.wait_to_become_active()).await {
Ok(Ok(())) => Ok(tenant),

View File

@@ -12,7 +12,9 @@
//!
use anyhow::{bail, Context};
use bytes::Bytes;
use futures::FutureExt;
use futures::Stream;
use pageserver_api::models::TimelineState;
use remote_storage::DownloadError;
use remote_storage::GenericRemoteStorage;
@@ -237,13 +239,14 @@ impl UninitializedTimeline<'_> {
/// Prepares timeline data by loading it from the basebackup archive.
pub async fn import_basebackup_from_tar(
self,
copyin_read: &mut (impl tokio::io::AsyncRead + Send + Sync + Unpin),
copyin_stream: &mut (impl Stream<Item = io::Result<Bytes>> + Sync + Send + Unpin),
base_lsn: Lsn,
ctx: &RequestContext,
) -> anyhow::Result<Arc<Timeline>> {
let raw_timeline = self.raw_timeline()?;
import_datadir::import_basebackup_from_tar(raw_timeline, copyin_read, base_lsn, ctx)
let mut reader = tokio_util::io::StreamReader::new(copyin_stream);
import_datadir::import_basebackup_from_tar(raw_timeline, &mut reader, base_lsn, ctx)
.await
.context("Failed to import basebackup")?;
@@ -1240,8 +1243,11 @@ impl Tenant {
"Cannot run GC iteration on inactive tenant"
);
self.gc_iteration_internal(target_timeline_id, horizon, pitr, ctx)
.await
let gc_result = self
.gc_iteration_internal(target_timeline_id, horizon, pitr, ctx)
.await;
gc_result
}
/// Perform one compaction iteration.
@@ -3170,44 +3176,6 @@ mod tests {
}
*/
#[tokio::test]
async fn test_get_branchpoints_from_an_inactive_timeline() -> anyhow::Result<()> {
let (tenant, ctx) =
TenantHarness::create("test_get_branchpoints_from_an_inactive_timeline")?
.load()
.await;
let tline = tenant
.create_empty_timeline(TIMELINE_ID, Lsn(0), DEFAULT_PG_VERSION, &ctx)?
.initialize(&ctx)?;
make_some_layers(tline.as_ref(), Lsn(0x20)).await?;
tenant
.branch_timeline(&tline, NEW_TIMELINE_ID, Some(Lsn(0x40)), &ctx)
.await?;
let newtline = tenant
.get_timeline(NEW_TIMELINE_ID, true)
.expect("Should have a local timeline");
make_some_layers(newtline.as_ref(), Lsn(0x60)).await?;
tline.set_state(TimelineState::Broken);
tenant
.gc_iteration(Some(TIMELINE_ID), 0x10, Duration::ZERO, &ctx)
.await?;
assert_eq!(
newtline.get(*TEST_KEY, Lsn(0x50), &ctx).await?,
TEST_IMG(&format!("foo at {}", Lsn(0x40)))
);
let branchpoints = &tline.gc_info.read().unwrap().retain_lsns;
assert_eq!(branchpoints.len(), 1);
assert_eq!(branchpoints[0], Lsn(0x40));
Ok(())
}
#[tokio::test]
async fn test_retain_data_in_parent_which_is_needed_for_child() -> anyhow::Result<()> {
let (tenant, ctx) =

View File

@@ -51,6 +51,9 @@ where
///
/// A "cursor" for efficiently reading multiple pages from a BlockReader
///
/// A cursor caches the last accessed page, allowing for faster access if the
/// same block is accessed repeatedly.
///
/// You can access the last page with `*cursor`. 'read_blk' returns 'self', so
/// that in many cases you can use a BlockCursor as a drop-in replacement for
/// the underlying BlockReader. For example:
@@ -70,6 +73,8 @@ where
R: BlockReader,
{
reader: R,
/// last accessed page
cache: Option<(u32, R::BlockLease)>,
}
impl<R> BlockCursor<R>
@@ -77,13 +82,40 @@ where
R: BlockReader,
{
pub fn new(reader: R) -> Self {
BlockCursor { reader }
BlockCursor {
reader,
cache: None,
}
}
pub fn read_blk(&mut self, blknum: u32) -> Result<R::BlockLease, std::io::Error> {
self.reader.read_blk(blknum)
pub fn read_blk(&mut self, blknum: u32) -> Result<&Self, std::io::Error> {
// Fast return if this is the same block as before
if let Some((cached_blk, _buf)) = &self.cache {
if *cached_blk == blknum {
return Ok(self);
}
}
// Read the block from the underlying reader, and cache it
self.cache = None;
let buf = self.reader.read_blk(blknum)?;
self.cache = Some((blknum, buf));
Ok(self)
}
}
impl<R> Deref for BlockCursor<R>
where
R: BlockReader,
{
type Target = [u8; PAGE_SZ];
fn deref(&self) -> &<Self as Deref>::Target {
&self.cache.as_ref().unwrap().1
}
}
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
/// An adapter for reading a (virtual) file using the page cache.

View File

@@ -222,6 +222,48 @@ impl TenantConfOpt {
eviction_policy: self.eviction_policy.unwrap_or(global_conf.eviction_policy),
}
}
pub fn update(&mut self, other: &TenantConfOpt) {
if let Some(checkpoint_distance) = other.checkpoint_distance {
self.checkpoint_distance = Some(checkpoint_distance);
}
if let Some(checkpoint_timeout) = other.checkpoint_timeout {
self.checkpoint_timeout = Some(checkpoint_timeout);
}
if let Some(compaction_target_size) = other.compaction_target_size {
self.compaction_target_size = Some(compaction_target_size);
}
if let Some(compaction_period) = other.compaction_period {
self.compaction_period = Some(compaction_period);
}
if let Some(compaction_threshold) = other.compaction_threshold {
self.compaction_threshold = Some(compaction_threshold);
}
if let Some(gc_horizon) = other.gc_horizon {
self.gc_horizon = Some(gc_horizon);
}
if let Some(gc_period) = other.gc_period {
self.gc_period = Some(gc_period);
}
if let Some(image_creation_threshold) = other.image_creation_threshold {
self.image_creation_threshold = Some(image_creation_threshold);
}
if let Some(pitr_interval) = other.pitr_interval {
self.pitr_interval = Some(pitr_interval);
}
if let Some(walreceiver_connect_timeout) = other.walreceiver_connect_timeout {
self.walreceiver_connect_timeout = Some(walreceiver_connect_timeout);
}
if let Some(lagging_wal_timeout) = other.lagging_wal_timeout {
self.lagging_wal_timeout = Some(lagging_wal_timeout);
}
if let Some(max_lsn_wal_lag) = other.max_lsn_wal_lag {
self.max_lsn_wal_lag = Some(max_lsn_wal_lag);
}
if let Some(trace_read_requests) = other.trace_read_requests {
self.trace_read_requests = Some(trace_read_requests);
}
}
}
impl Default for TenantConf {

View File

@@ -2,7 +2,9 @@
//! used to keep in-memory layers spilled on disk.
use crate::config::PageServerConf;
use crate::page_cache::{self, ReadBufResult, WriteBufResult, PAGE_SZ};
use crate::page_cache;
use crate::page_cache::PAGE_SZ;
use crate::page_cache::{ReadBufResult, WriteBufResult};
use crate::tenant::blob_io::BlobWriter;
use crate::tenant::block_io::BlockReader;
use crate::virtual_file::VirtualFile;
@@ -425,6 +427,7 @@ mod tests {
let actual = cursor.read_blob(pos)?;
assert_eq!(actual, expected);
}
drop(cursor);
// Test a large blob that spans multiple pages
let mut large_data = Vec::new();

View File

@@ -154,7 +154,11 @@ where
expected: &Arc<L>,
new: Arc<L>,
) -> anyhow::Result<Replacement<Arc<L>>> {
fail::fail_point!("layermap-replace-notfound", |_| Ok(Replacement::NotFound));
fail::fail_point!("layermap-replace-notfound", |_| Ok(
// this is not what happens if an L0 layer was not found a anyhow error but perhaps
// that should be changed. this is good enough to show a replacement failure.
Replacement::NotFound
));
self.layer_map.replace_historic_noflush(expected, new)
}
@@ -336,15 +340,12 @@ where
let l0_index = if expected_l0 {
// find the index in case replace worked, we need to replace that as well
let pos = self
.l0_delta_layers
.iter()
.position(|slot| Self::compare_arced_layers(slot, expected));
if pos.is_none() {
return Ok(Replacement::NotFound);
}
pos
Some(
self.l0_delta_layers
.iter()
.position(|slot| Self::compare_arced_layers(slot, expected))
.ok_or_else(|| anyhow::anyhow!("existing l0 delta layer was not found"))?,
)
} else {
None
};
@@ -803,26 +804,6 @@ mod tests {
)
}
#[test]
fn replacing_missing_l0_is_notfound() {
// original impl had an oversight, and L0 was an anyhow::Error. anyhow::Error should
// however only happen for precondition failures.
let layer = "000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000053423C21-0000000053424D69";
let layer = LayerFileName::from_str(layer).unwrap();
let layer = LayerDescriptor::from(layer);
// same skeletan construction; see scenario below
let not_found: Arc<dyn Layer> = Arc::new(layer.clone());
let new_version: Arc<dyn Layer> = Arc::new(layer);
let mut map = LayerMap::default();
let res = map.batch_update().replace_historic(&not_found, new_version);
assert!(matches!(res, Ok(Replacement::NotFound)), "{res:?}");
}
fn l0_delta_layers_updated_scenario(layer_name: &str, expected_l0: bool) {
let name = LayerFileName::from_str(layer_name).unwrap();
let skeleton = LayerDescriptor::from(name);
@@ -832,8 +813,7 @@ mod tests {
let mut map = LayerMap::default();
// two disjoint Arcs in different lifecycle phases. even if it seems they must be the
// same layer, we use LayerMap::compare_arced_layers as the identity of layers.
// two disjoint Arcs in different lifecycle phases.
assert!(!LayerMap::compare_arced_layers(&remote, &downloaded));
let expected_in_counts = (1, usize::from(expected_l0));

View File

@@ -289,7 +289,7 @@ pub async fn set_new_tenant_config(
conf: &'static PageServerConf,
new_tenant_conf: TenantConfOpt,
tenant_id: TenantId,
) -> Result<(), TenantStateError> {
) -> anyhow::Result<()> {
info!("configuring tenant {tenant_id}");
let tenant = get_tenant(tenant_id, true).await?;
@@ -306,20 +306,16 @@ pub async fn set_new_tenant_config(
/// Gets the tenant from the in-memory data, erroring if it's absent or is not fitting to the query.
/// `active_only = true` allows to query only tenants that are ready for operations, erroring on other kinds of tenants.
pub async fn get_tenant(
tenant_id: TenantId,
active_only: bool,
) -> Result<Arc<Tenant>, TenantStateError> {
pub async fn get_tenant(tenant_id: TenantId, active_only: bool) -> anyhow::Result<Arc<Tenant>> {
let m = TENANTS.read().await;
let tenant = m
.get(&tenant_id)
.ok_or(TenantStateError::NotFound(tenant_id))?;
.with_context(|| format!("Tenant {tenant_id} not found in the local state"))?;
if active_only && !tenant.is_active() {
tracing::warn!(
anyhow::bail!(
"Tenant {tenant_id} is not active. Current state: {:?}",
tenant.current_state()
);
Err(TenantStateError::NotActive(tenant_id))
)
} else {
Ok(Arc::clone(tenant))
}
@@ -329,28 +325,21 @@ pub async fn delete_timeline(
tenant_id: TenantId,
timeline_id: TimelineId,
ctx: &RequestContext,
) -> Result<(), TenantStateError> {
let tenant = get_tenant(tenant_id, true).await?;
tenant.delete_timeline(timeline_id, ctx).await?;
Ok(())
}
) -> anyhow::Result<()> {
match get_tenant(tenant_id, true).await {
Ok(tenant) => {
tenant.delete_timeline(timeline_id, ctx).await?;
}
Err(e) => anyhow::bail!("Cannot access tenant {tenant_id} in local tenant state: {e:?}"),
}
#[derive(Debug, thiserror::Error)]
pub enum TenantStateError {
#[error("Tenant {0} not found")]
NotFound(TenantId),
#[error("Tenant {0} is stopping")]
IsStopping(TenantId),
#[error("Tenant {0} is not active")]
NotActive(TenantId),
#[error(transparent)]
Other(#[from] anyhow::Error),
Ok(())
}
pub async fn detach_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
) -> Result<(), TenantStateError> {
) -> anyhow::Result<()> {
remove_tenant_from_memory(tenant_id, async {
let local_tenant_directory = conf.tenant_path(&tenant_id);
fs::remove_dir_all(&local_tenant_directory)
@@ -390,7 +379,7 @@ pub async fn load_tenant(
pub async fn ignore_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
) -> Result<(), TenantStateError> {
) -> anyhow::Result<()> {
remove_tenant_from_memory(tenant_id, async {
let ignore_mark_file = conf.tenant_ignore_mark_file_path(tenant_id);
fs::File::create(&ignore_mark_file)
@@ -500,7 +489,7 @@ where
async fn remove_tenant_from_memory<V, F>(
tenant_id: TenantId,
tenant_cleanup: F,
) -> Result<V, TenantStateError>
) -> anyhow::Result<V>
where
F: std::future::Future<Output = anyhow::Result<V>>,
{
@@ -516,9 +505,11 @@ where
| TenantState::Loading
| TenantState::Broken
| TenantState::Active => tenant.set_stopping(),
TenantState::Stopping => return Err(TenantStateError::IsStopping(tenant_id)),
TenantState::Stopping => {
anyhow::bail!("Tenant {tenant_id} is stopping already")
}
},
None => return Err(TenantStateError::NotFound(tenant_id)),
None => anyhow::bail!("Tenant not found for id {tenant_id}"),
}
}
@@ -541,15 +532,10 @@ where
Err(e) => {
let tenants_accessor = TENANTS.read().await;
match tenants_accessor.get(&tenant_id) {
Some(tenant) => {
tenant.set_broken(&e.to_string());
}
None => {
warn!("Tenant {tenant_id} got removed from memory");
return Err(TenantStateError::NotFound(tenant_id));
}
Some(tenant) => tenant.set_broken(&e.to_string()),
None => warn!("Tenant {tenant_id} got removed from memory"),
}
Err(TenantStateError::Other(e))
Err(e)
}
}
}
@@ -569,7 +555,7 @@ pub async fn immediate_gc(
let tenant = guard
.get(&tenant_id)
.map(Arc::clone)
.with_context(|| format!("tenant {tenant_id}"))
.with_context(|| format!("Tenant {tenant_id} not found"))
.map_err(ApiError::NotFound)?;
let gc_horizon = gc_req.gc_horizon.unwrap_or_else(|| tenant.get_gc_horizon());
@@ -619,7 +605,7 @@ pub async fn immediate_compact(
let tenant = guard
.get(&tenant_id)
.map(Arc::clone)
.with_context(|| format!("tenant {tenant_id}"))
.with_context(|| format!("Tenant {tenant_id} not found"))
.map_err(ApiError::NotFound)?;
let timeline = tenant

View File

@@ -218,10 +218,9 @@ use tracing::{debug, info, warn};
use tracing::{info_span, Instrument};
use utils::lsn::Lsn;
use crate::metrics::{
MeasureRemoteOp, RemoteOpFileKind, RemoteOpKind, RemoteTimelineClientMetrics,
REMOTE_ONDEMAND_DOWNLOADED_BYTES, REMOTE_ONDEMAND_DOWNLOADED_LAYERS,
};
use crate::metrics::RemoteOpFileKind;
use crate::metrics::RemoteOpKind;
use crate::metrics::{MeasureRemoteOp, RemoteTimelineClientMetrics};
use crate::tenant::remote_timeline_client::index::LayerFileMetadata;
use crate::{
config::PageServerConf,
@@ -447,10 +446,6 @@ impl RemoteTimelineClient {
);
}
}
REMOTE_ONDEMAND_DOWNLOADED_LAYERS.inc();
REMOTE_ONDEMAND_DOWNLOADED_BYTES.inc_by(downloaded_size);
Ok(downloaded_size)
}

View File

@@ -6,13 +6,11 @@
use std::collections::HashSet;
use std::future::Future;
use std::path::Path;
use std::time::Duration;
use anyhow::{anyhow, Context};
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tracing::{info, warn};
use tracing::{error, info, warn};
use crate::config::PageServerConf;
use crate::tenant::storage_layer::LayerFileName;
@@ -28,8 +26,6 @@ async fn fsync_path(path: impl AsRef<std::path::Path>) -> Result<(), std::io::Er
fs::File::open(path).await?.sync_all().await
}
static MAX_DOWNLOAD_DURATION: Duration = Duration::from_secs(120);
///
/// If 'metadata' is given, we will validate that the downloaded file's size matches that
/// in the metadata. (In the future, we might do more cross-checks, like CRC validation)
@@ -68,28 +64,22 @@ pub async fn download_layer_file<'a>(
// TODO: this doesn't use the cached fd for some reason?
let mut destination_file = fs::File::create(&temp_file_path).await.with_context(|| {
format!(
"create a destination file for layer '{}'",
"Failed to create a destination file for layer '{}'",
temp_file_path.display()
)
})
.map_err(DownloadError::Other)?;
let mut download = storage.download(&remote_path).await.with_context(|| {
format!(
"open a download stream for layer with remote storage path '{remote_path:?}'"
"Failed to open a download stream for layer with remote storage path '{remote_path:?}'"
)
})
.map_err(DownloadError::Other)?;
let bytes_amount = tokio::time::timeout(MAX_DOWNLOAD_DURATION, tokio::io::copy(&mut download.download_stream, &mut destination_file))
.await
.map_err(|e| DownloadError::Other(anyhow::anyhow!("Timed out {:?}", e)))?
.with_context(|| {
format!("Failed to download layer with remote storage path '{remote_path:?}' into file {temp_file_path:?}")
})
.map_err(DownloadError::Other)?;
let bytes_amount = tokio::io::copy(&mut download.download_stream, &mut destination_file).await.with_context(|| {
format!("Failed to download layer with remote storage path '{remote_path:?}' into file {temp_file_path:?}")
})
.map_err(DownloadError::Other)?;
Ok((destination_file, bytes_amount))
},
&format!("download {remote_path:?}"),
).await?;
@@ -310,7 +300,7 @@ where
}
Err(DownloadError::Other(ref err)) => {
// Operation failed FAILED_DOWNLOAD_RETRIES times. Time to give up.
warn!("{description} still failed after {attempts} retries, giving up: {err:?}");
error!("{description} still failed after {attempts} retries, giving up: {err:?}");
return result;
}
}

View File

@@ -1246,11 +1246,6 @@ impl Timeline {
state,
};
info!(
"initialized lsrlsn to {}/{}",
disk_consistent_lsn,
metadata.prev_record_lsn().unwrap_or(Lsn(0))
);
result.repartition_threshold = result.get_checkpoint_distance() / 10;
result
.metrics
@@ -1340,7 +1335,6 @@ impl Timeline {
lagging_wal_timeout,
max_lsn_wal_lag,
crate::config::SAFEKEEPER_AUTH_TOKEN.get().cloned(),
self.conf.availability_zone.clone(),
background_ctx,
);
}
@@ -2721,22 +2715,10 @@ impl Timeline {
) -> Result<HashMap<LayerFileName, LayerFileMetadata>, PageReconstructError> {
let timer = self.metrics.create_images_time_histo.start_timer();
let mut image_layers: Vec<ImageLayer> = Vec::new();
// We need to avoid holes between generated image layers.
// Otherwise LayerMap::image_layer_exists will return false if key range of some layer is covered by more than one
// image layer with hole between them. In this case such layer can not be utilized by GC.
//
// How such hole between partitions can appear?
// if we have relation with relid=1 and size 100 and relation with relid=2 with size 200 then result of
// KeySpace::partition may contain partitions <100000000..100000099> and <200000000..200000199>.
// If there is delta layer <100000000..300000000> then it never be garbage collected because
// image layers <100000000..100000099> and <200000000..200000199> are not completely covering it.
let mut start = Key::MIN;
for partition in partitioning.parts.iter() {
let img_range = start..partition.ranges.last().unwrap().end;
start = img_range.end;
if force || self.time_for_new_image_layer(partition, lsn)? {
let img_range =
partition.ranges.first().unwrap().start..partition.ranges.last().unwrap().end;
let mut image_layer_writer = ImageLayerWriter::new(
self.conf,
self.timeline_id,
@@ -2750,6 +2732,7 @@ impl Timeline {
"failpoint image-layer-writer-fail-before-finish"
)))
});
for range in &partition.ranges {
let mut key = range.start;
while key < range.end {
@@ -3164,7 +3147,9 @@ impl Timeline {
}
fail_point!("delta-layer-writer-fail-before-finish", |_| {
Err(anyhow::anyhow!("failpoint delta-layer-writer-fail-before-finish").into())
return Err(
anyhow::anyhow!("failpoint delta-layer-writer-fail-before-finish").into(),
);
});
writer.as_mut().unwrap().put_value(key, lsn, value)?;
@@ -3834,7 +3819,7 @@ impl Timeline {
remote_layer.ongoing_download.close();
} else {
// Keep semaphore open. We'll drop the permit at the end of the function.
error!("on-demand download failed: {:?}", result.as_ref().unwrap_err());
info!("on-demand download failed: {:?}", result.as_ref().unwrap_err());
}
// Don't treat it as an error if the task that triggered the download

View File

@@ -45,7 +45,6 @@ pub fn spawn_connection_manager_task(
lagging_wal_timeout: Duration,
max_lsn_wal_lag: NonZeroU64,
auth_token: Option<Arc<String>>,
availability_zone: Option<String>,
ctx: RequestContext,
) {
let mut broker_client = get_broker_client().clone();
@@ -68,7 +67,6 @@ pub fn spawn_connection_manager_task(
lagging_wal_timeout,
max_lsn_wal_lag,
auth_token,
availability_zone,
);
loop {
select! {
@@ -336,7 +334,6 @@ struct WalreceiverState {
/// Data about all timelines, available for connection, fetched from storage broker, grouped by their corresponding safekeeper node id.
wal_stream_candidates: HashMap<NodeId, BrokerSkTimeline>,
auth_token: Option<Arc<String>>,
availability_zone: Option<String>,
}
/// Current connection data.
@@ -384,7 +381,6 @@ impl WalreceiverState {
lagging_wal_timeout: Duration,
max_lsn_wal_lag: NonZeroU64,
auth_token: Option<Arc<String>>,
availability_zone: Option<String>,
) -> Self {
let id = TenantTimelineId {
tenant_id: timeline.tenant_id,
@@ -400,7 +396,6 @@ impl WalreceiverState {
wal_stream_candidates: HashMap::new(),
wal_connection_retries: HashMap::new(),
auth_token,
availability_zone,
}
}
@@ -745,7 +740,6 @@ impl WalreceiverState {
None => None,
Some(x) => Some(x),
},
self.availability_zone.as_deref(),
) {
Ok(connstr) => Some((*sk_id, info, connstr)),
Err(e) => {
@@ -830,24 +824,17 @@ fn wal_stream_connection_config(
}: TenantTimelineId,
listen_pg_addr_str: &str,
auth_token: Option<&str>,
availability_zone: Option<&str>,
) -> anyhow::Result<PgConnectionConfig> {
let (host, port) =
parse_host_port(listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?;
let port = port.unwrap_or(5432);
let mut connstr = PgConnectionConfig::new_host_port(host, port)
Ok(PgConnectionConfig::new_host_port(host, port)
.extend_options([
"-c".to_owned(),
format!("timeline_id={}", timeline_id),
format!("tenant_id={}", tenant_id),
])
.set_password(auth_token.map(|s| s.to_owned()));
if let Some(availability_zone) = availability_zone {
connstr = connstr.extend_options([format!("availability_zone={}", availability_zone)]);
}
Ok(connstr)
.set_password(auth_token.map(|s| s.to_owned())))
}
#[cfg(test)]
@@ -1286,7 +1273,6 @@ mod tests {
wal_stream_candidates: HashMap::new(),
wal_connection_retries: HashMap::new(),
auth_token: None,
availability_zone: None,
}
}
}

View File

@@ -33,11 +33,10 @@ use crate::{
walingest::WalIngest,
walrecord::DecodedWALRecord,
};
use postgres_backend::is_expected_io_error;
use postgres_connection::PgConnectionConfig;
use postgres_ffi::waldecoder::WalStreamDecoder;
use pq_proto::ReplicationFeedback;
use utils::lsn::Lsn;
use utils::{lsn::Lsn, postgres_backend_async::is_expected_io_error};
/// Status of the connection.
#[derive(Debug, Clone, Copy)]
@@ -354,7 +353,7 @@ pub async fn handle_walreceiver_connection(
debug!("neon_status_update {status_update:?}");
let mut data = BytesMut::new();
status_update.serialize(&mut data);
status_update.serialize(&mut data)?;
physical_stream
.as_mut()
.zenith_status_update(data.len() as u64, &data)
@@ -435,8 +434,8 @@ fn ignore_expected_errors(pg_error: postgres::Error) -> anyhow::Result<postgres:
{
return Ok(pg_error);
} else if let Some(db_error) = pg_error.as_db_error() {
if db_error.code() == &SqlState::SUCCESSFUL_COMPLETION
&& db_error.message().contains("ending streaming")
if db_error.code() == &SqlState::CONNECTION_FAILURE
&& db_error.message().contains("end streaming")
{
return Ok(pg_error);
}

View File

@@ -23,11 +23,13 @@ use bytes::{BufMut, Bytes, BytesMut};
use nix::poll::*;
use serde::Serialize;
use std::collections::VecDeque;
use std::fs::OpenOptions;
use std::io::prelude::*;
use std::io::{Error, ErrorKind};
use std::ops::{Deref, DerefMut};
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::prelude::CommandExt;
use std::path::PathBuf;
use std::process::Stdio;
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
use std::sync::{Mutex, MutexGuard};
@@ -254,53 +256,52 @@ impl PostgresRedoManager {
pg_version: u32,
) -> Result<Bytes, WalRedoError> {
let (rel, blknum) = key_to_rel_block(key).or(Err(WalRedoError::InvalidRecord))?;
const MAX_RETRY_ATTEMPTS: u32 = 1;
let start_time = Instant::now();
let mut n_attempts = 0u32;
loop {
let mut proc = self.stdin.lock().unwrap();
let lock_time = Instant::now();
// launch the WAL redo process on first use
if proc.is_none() {
self.launch(&mut proc, pg_version)?;
}
WAL_REDO_WAIT_TIME.observe(lock_time.duration_since(start_time).as_secs_f64());
let mut proc = self.stdin.lock().unwrap();
let lock_time = Instant::now();
// Relational WAL records are applied using wal-redo-postgres
let buf_tag = BufferTag { rel, blknum };
let result = self
.apply_wal_records(proc, buf_tag, &base_img, records, wal_redo_timeout)
.map_err(WalRedoError::IoError);
// launch the WAL redo process on first use
if proc.is_none() {
self.launch(&mut proc, pg_version)?;
}
WAL_REDO_WAIT_TIME.observe(lock_time.duration_since(start_time).as_secs_f64());
let end_time = Instant::now();
let duration = end_time.duration_since(lock_time);
// Relational WAL records are applied using wal-redo-postgres
let buf_tag = BufferTag { rel, blknum };
let result = self
.apply_wal_records(proc, buf_tag, base_img, records, wal_redo_timeout)
.map_err(WalRedoError::IoError);
let len = records.len();
let nbytes = records.iter().fold(0, |acumulator, record| {
acumulator
+ match &record.1 {
NeonWalRecord::Postgres { rec, .. } => rec.len(),
_ => unreachable!("Only PostgreSQL records are accepted in this batch"),
}
});
let end_time = Instant::now();
let duration = end_time.duration_since(lock_time);
WAL_REDO_TIME.observe(duration.as_secs_f64());
WAL_REDO_RECORDS_HISTOGRAM.observe(len as f64);
WAL_REDO_BYTES_HISTOGRAM.observe(nbytes as f64);
let len = records.len();
let nbytes = records.iter().fold(0, |acumulator, record| {
acumulator
+ match &record.1 {
NeonWalRecord::Postgres { rec, .. } => rec.len(),
_ => unreachable!("Only PostgreSQL records are accepted in this batch"),
}
});
debug!(
"postgres applied {} WAL records ({} bytes) in {} us to reconstruct page image at LSN {}",
len,
nbytes,
duration.as_micros(),
lsn
);
WAL_REDO_TIME.observe(duration.as_secs_f64());
WAL_REDO_RECORDS_HISTOGRAM.observe(len as f64);
WAL_REDO_BYTES_HISTOGRAM.observe(nbytes as f64);
// If something went wrong, don't try to reuse the process. Kill it, and
// next request will launch a new one.
if result.is_err() {
error!(
debug!(
"postgres applied {} WAL records ({} bytes) in {} us to reconstruct page image at LSN {}",
len,
nbytes,
duration.as_micros(),
lsn
);
// If something went wrong, don't try to reuse the process. Kill it, and
// next request will launch a new one.
if result.is_err() {
error!(
"error applying {} WAL records {}..{} ({} bytes) to base image with LSN {} to reconstruct page image at LSN {}",
records.len(),
records.first().map(|p| p.0).unwrap_or(Lsn(0)),
@@ -309,28 +310,24 @@ impl PostgresRedoManager {
base_img_lsn,
lsn
);
// self.stdin only holds stdin & stderr as_raw_fd().
// Dropping it as part of take() doesn't close them.
// The owning objects (ChildStdout and ChildStderr) are stored in
// self.stdout and self.stderr, respsectively.
// We intentionally keep them open here to avoid a race between
// currently running `apply_wal_records()` and a `launch()` call
// after we return here.
// The currently running `apply_wal_records()` must not read from
// the newly launched process.
// By keeping self.stdout and self.stderr open here, `launch()` will
// get other file descriptors for the new child's stdout and stderr,
// and hence the current `apply_wal_records()` calls will observe
// `output.stdout.as_raw_fd() != stdout_fd` .
if let Some(proc) = self.stdin.lock().unwrap().take() {
proc.child.kill_and_wait();
}
}
n_attempts += 1;
if n_attempts > MAX_RETRY_ATTEMPTS || result.is_ok() {
return result;
// self.stdin only holds stdin & stderr as_raw_fd().
// Dropping it as part of take() doesn't close them.
// The owning objects (ChildStdout and ChildStderr) are stored in
// self.stdout and self.stderr, respsectively.
// We intentionally keep them open here to avoid a race between
// currently running `apply_wal_records()` and a `launch()` call
// after we return here.
// The currently running `apply_wal_records()` must not read from
// the newly launched process.
// By keeping self.stdout and self.stderr open here, `launch()` will
// get other file descriptors for the new child's stdout and stderr,
// and hence the current `apply_wal_records()` calls will observe
// `output.stdout.as_raw_fd() != stdout_fd` .
if let Some(proc) = self.stdin.lock().unwrap().take() {
proc.child.kill_and_wait();
}
}
result
}
///
@@ -637,26 +634,26 @@ impl PostgresRedoManager {
input: &mut MutexGuard<Option<ProcessInput>>,
pg_version: u32,
) -> Result<(), Error> {
// Previous versions of wal-redo required data directory and that directories
// occupied some space on disk. Remove it if we face it.
//
// This code could be dropped after one release cycle.
let legacy_datadir = path_with_suffix_extension(
// FIXME: We need a dummy Postgres cluster to run the process in. Currently, we
// just create one with constant name. That fails if you try to launch more than
// one WAL redo manager concurrently.
let datadir = path_with_suffix_extension(
self.conf
.tenant_path(&self.tenant_id)
.join("wal-redo-datadir"),
TEMP_FILE_SUFFIX,
);
if legacy_datadir.exists() {
info!("legacy wal-redo datadir {legacy_datadir:?} exists, removing");
fs::remove_dir_all(&legacy_datadir).map_err(|e| {
// Create empty data directory for wal-redo postgres, deleting old one first.
if datadir.exists() {
info!("old temporary datadir {datadir:?} exists, removing");
fs::remove_dir_all(&datadir).map_err(|e| {
Error::new(
e.kind(),
format!("legacy wal-redo datadir {legacy_datadir:?} removal failure: {e}"),
format!("Old temporary dir {datadir:?} removal failure: {e}"),
)
})?;
}
let pg_bin_dir_path = self
.conf
.pg_bin_dir(pg_version)
@@ -666,6 +663,35 @@ impl PostgresRedoManager {
.pg_lib_dir(pg_version)
.map_err(|e| Error::new(ErrorKind::Other, format!("incorrect pg_lib_dir path: {e}")))?;
info!("running initdb in {}", datadir.display());
let initdb = Command::new(pg_bin_dir_path.join("initdb"))
.args(["-D", &datadir.to_string_lossy()])
.arg("-N")
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path) // macOS
.close_fds()
.output()
.map_err(|e| Error::new(e.kind(), format!("failed to execute initdb: {e}")))?;
if !initdb.status.success() {
return Err(Error::new(
ErrorKind::Other,
format!(
"initdb failed\nstdout: {}\nstderr:\n{}",
String::from_utf8_lossy(&initdb.stdout),
String::from_utf8_lossy(&initdb.stderr)
),
));
} else {
// Limit shared cache for wal-redo-postgres
let mut config = OpenOptions::new()
.append(true)
.open(PathBuf::from(&datadir).join("postgresql.conf"))?;
config.write_all(b"shared_buffers=128kB\n")?;
config.write_all(b"fsync=off\n")?;
}
// Start postgres itself
let child = Command::new(pg_bin_dir_path.join("postgres"))
.arg("--wal-redo")
@@ -675,6 +701,7 @@ impl PostgresRedoManager {
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path)
.env("PGDATA", &datadir)
// The redo process is not trusted, and runs in seccomp mode that
// doesn't allow it to open any files. We have to also make sure it
// doesn't inherit any file descriptors from the pageserver, that
@@ -744,7 +771,7 @@ impl PostgresRedoManager {
&self,
mut input: MutexGuard<Option<ProcessInput>>,
tag: BufferTag,
base_img: &Option<Bytes>,
base_img: Option<Bytes>,
records: &[(Lsn, NeonWalRecord)],
wal_redo_timeout: Duration,
) -> Result<Bytes, std::io::Error> {
@@ -760,7 +787,7 @@ impl PostgresRedoManager {
let mut writebuf: Vec<u8> = Vec::with_capacity((BLCKSZ as usize) * 3);
build_begin_redo_for_block_msg(tag, &mut writebuf);
if let Some(img) = base_img {
build_push_page_msg(tag, img, &mut writebuf);
build_push_page_msg(tag, &img, &mut writebuf);
}
for (lsn, rec) in records.iter() {
if let NeonWalRecord::Postgres {

View File

@@ -13,14 +13,13 @@ OBJS = \
walproposer.o \
walproposer_utils.o
PG_CPPFLAGS = -I$(libpq_srcdir)
SHLIB_LINK_INTERNAL = $(libpq)
PG_CPPFLAGS = -I$(libpq_srcdir) -DSIMLIB
PG_LIBS = $(libpq)
EXTENSION = neon
DATA = neon--1.0.sql
PGFILEDESC = "neon - cloud storage for PostgreSQL"
PG_CONFIG = pg_config
PGXS := $(shell $(PG_CONFIG) --pgxs)
include $(PGXS)

View File

@@ -32,9 +32,6 @@
#define PageStoreTrace DEBUG5
#define MAX_RECONNECT_ATTEMPTS 5
#define RECONNECT_INTERVAL_USEC 1000000
bool connected = false;
PGconn *pageserver_conn = NULL;
@@ -55,8 +52,8 @@ int readahead_buffer_size = 128;
static void pageserver_flush(void);
static bool
pageserver_connect(int elevel)
static void
pageserver_connect()
{
char *query;
int ret;
@@ -72,11 +69,10 @@ pageserver_connect(int elevel)
PQfinish(pageserver_conn);
pageserver_conn = NULL;
ereport(elevel,
ereport(ERROR,
(errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
errmsg(NEON_TAG "could not establish connection to pageserver"),
errdetail_internal("%s", msg)));
return false;
}
query = psprintf("pagestream %s %s", neon_tenant, neon_timeline);
@@ -85,8 +81,7 @@ pageserver_connect(int elevel)
{
PQfinish(pageserver_conn);
pageserver_conn = NULL;
neon_log(elevel, "could not send pagestream command to pageserver");
return false;
neon_log(ERROR, "could not send pagestream command to pageserver");
}
pageserver_conn_wes = CreateWaitEventSet(TopMemoryContext, 3);
@@ -118,9 +113,8 @@ pageserver_connect(int elevel)
FreeWaitEventSet(pageserver_conn_wes);
pageserver_conn_wes = NULL;
neon_log(elevel, "could not complete handshake with pageserver: %s",
neon_log(ERROR, "could not complete handshake with pageserver: %s",
msg);
return false;
}
}
}
@@ -128,7 +122,6 @@ pageserver_connect(int elevel)
neon_log(LOG, "libpagestore: connected to '%s'", page_server_connstring_raw);
connected = true;
return true;
}
/*
@@ -156,11 +149,8 @@ retry:
if (event.events & WL_SOCKET_READABLE)
{
if (!PQconsumeInput(pageserver_conn))
{
neon_log(LOG, "could not get response from pageserver: %s",
neon_log(ERROR, "could not get response from pageserver: %s",
PQerrorMessage(pageserver_conn));
return -1;
}
}
goto retry;
@@ -200,62 +190,31 @@ static void
pageserver_send(NeonRequest * request)
{
StringInfoData req_buff;
int n_reconnect_attempts = 0;
/* If the connection was lost for some reason, reconnect */
if (connected && PQstatus(pageserver_conn) == CONNECTION_BAD)
pageserver_disconnect();
if (!connected)
pageserver_connect();
req_buff = nm_pack_request(request);
/*
* If pageserver is stopped, the connections from compute node are broken.
* The compute node doesn't notice that immediately, but it will cause the next request to fail, usually on the next query.
* That causes user-visible errors if pageserver is restarted, or the tenant is moved from one pageserver to another.
* See https://github.com/neondatabase/neon/issues/1138
* So try to reestablish connection in case of failure.
* Send request.
*
* In principle, this could block if the output buffer is full, and we
* should use async mode and check for interrupts while waiting. In
* practice, our requests are small enough to always fit in the output and
* TCP buffer.
*/
while (true)
if (PQputCopyData(pageserver_conn, req_buff.data, req_buff.len) <= 0)
{
if (!connected)
{
if (!pageserver_connect(n_reconnect_attempts < MAX_RECONNECT_ATTEMPTS ? LOG : ERROR))
{
n_reconnect_attempts += 1;
pg_usleep(RECONNECT_INTERVAL_USEC);
continue;
}
}
char *msg = pchomp(PQerrorMessage(pageserver_conn));
/*
* Send request.
*
* In principle, this could block if the output buffer is full, and we
* should use async mode and check for interrupts while waiting. In
* practice, our requests are small enough to always fit in the output and
* TCP buffer.
*/
if (PQputCopyData(pageserver_conn, req_buff.data, req_buff.len) <= 0)
{
char *msg = pchomp(PQerrorMessage(pageserver_conn));
if (n_reconnect_attempts < MAX_RECONNECT_ATTEMPTS)
{
neon_log(LOG, "failed to send page request (try to reconnect): %s", msg);
if (n_reconnect_attempts != 0) /* do not sleep before first reconnect attempt, assuming that pageserver is already restarted */
pg_usleep(RECONNECT_INTERVAL_USEC);
n_reconnect_attempts += 1;
continue;
}
else
{
pageserver_disconnect();
neon_log(ERROR, "failed to send page request: %s", msg);
}
}
break;
pageserver_disconnect();
neon_log(ERROR, "failed to send page request: %s", msg);
}
pfree(req_buff.data);
n_unflushed_requests++;

Some files were not shown because too many files have changed in this diff Show More