Compare commits

..

50 Commits

Author SHA1 Message Date
Arthur Petukhovsky
6a00ad3aab Validate logs 2023-09-18 17:12:40 +00:00
Arthur Petukhovsky
61e6b24cb2 Fix fd leak 2023-09-18 09:54:19 +00:00
Arthur Petukhovsky
44c7d96ed0 Cleanup resources better 2023-09-16 23:10:53 +00:00
Arthur Petukhovsky
10ad3ae4eb Fix WAL page header 2023-09-16 20:50:10 +00:00
Arthur Petukhovsky
eb2886b401 Add test for 1000 WAL messages 2023-09-16 14:58:26 +00:00
Arthur Petukhovsky
0dc262a84a Fix bug in walproposer voting 2023-08-29 14:11:04 +00:00
Arthur Petukhovsky
d801ba7248 Print network config 2023-08-29 14:10:57 +00:00
Arthur Petukhovsky
1effb586ba Make network random unstable 2023-08-29 13:55:36 +00:00
Arthur Petukhovsky
2fd351fd63 Hide debug logs 2023-08-29 10:05:18 +00:00
Arthur Petukhovsky
13e94bf687 Fix truncateLsn bug 2023-08-29 09:03:22 +00:00
Arthur Petukhovsky
41b9750e81 Run many schedules 2023-08-24 23:42:11 +00:00
Arthur Petukhovsky
f8729f046d Fix excessive logs 2023-08-24 17:25:44 +00:00
Arthur Petukhovsky
420d3bc18f Add simulation schedule 2023-08-24 15:24:38 +00:00
Arthur Petukhovsky
33f7877d1b Show simulation time in logs 2023-08-23 10:10:11 +00:00
Arthur Petukhovsky
7de94c959a Support walproposer recovery 2023-08-22 23:15:46 +00:00
Arthur Petukhovsky
731ed3bb64 Support virtual disk in tests 2023-08-17 13:09:55 +00:00
Arthur Petukhovsky
413ce2cfe8 Crash safekeepers 2023-08-17 10:36:23 +00:00
Arthur Petukhovsky
7f36028fab Generate WAL in tests 2023-08-03 16:58:41 +00:00
Arthur Petukhovsky
cb6a8d3fe3 Fix some warnings 2023-07-28 21:37:16 +00:00
Arthur Petukhovsky
095747afc0 Fix walproposer main loop 2023-07-28 21:18:08 +00:00
Arthur Petukhovsky
89bd7ab8a3 Fix read/write in walproposer 2023-07-28 15:14:24 +00:00
Arthur Petukhovsky
5034a8cca0 WIP 2023-07-26 22:51:19 +02:00
Arthur Petukhovsky
55e40d090e Run sync several times 2023-07-25 11:16:47 +00:00
Arthur Petukhovsky
d87e822169 Return LSN from sync safekeepers 2023-07-24 21:15:35 +00:00
Arthur Petukhovsky
296a0cbac2 Add -DSIMLIB 2023-07-21 15:40:47 +00:00
Arthur Petukhovsky
aed14f52d5 Test sync safekeepers 2023-06-03 19:11:28 +00:00
Arthur Petukhovsky
909d7fadb8 Implement simlib sk server 2023-06-02 14:49:55 +00:00
Arthur Petukhovsky
3840d6b18b Clean up C API 2023-06-01 09:38:07 +00:00
Arthur Petukhovsky
65f92232e6 Compile walproposer 2023-05-31 21:06:47 +00:00
Arthur Petukhovsky
0d4f987fc8 Implement full simlib C API 2023-05-31 20:25:25 +00:00
Arthur Petukhovsky
aa0763d49d Run simulator on C code 2023-05-31 16:55:16 +00:00
Arthur Petukhovsky
7b5123edda Fix elog 2023-05-31 15:06:26 +00:00
Arthur Petukhovsky
b6a80bc269 Link postgres to rust statically 2023-05-31 13:19:41 +00:00
Arthur Petukhovsky
ac82b34c64 Create more involved example 2023-05-30 16:43:33 +00:00
Arthur Petukhovsky
a77fc2c5ff Test Rust -> C -> Rust codepath 2023-05-30 16:38:32 +00:00
Arthur Petukhovsky
9ccbec0e14 Spend some time 2023-05-26 14:45:25 +03:00
Arthur Petukhovsky
b55005d2c4 Build simple C func example 2023-05-26 14:44:48 +03:00
Arthur Petukhovsky
6436432a77 Showcase network failures 2023-05-25 12:53:20 +03:00
Arthur Petukhovsky
1b8918e665 Add accept, close and delays to the network 2023-05-25 12:26:57 +03:00
Arthur Petukhovsky
87c9edac7c Add basic support for network delays 2023-05-24 20:28:53 +03:00
Arthur Petukhovsky
5e0550a620 Add os.sleep and os.random 2023-05-24 15:51:30 +03:00
Arthur Petukhovsky
06f493f525 Extract simlib 2023-05-24 13:06:42 +03:00
Arthur Petukhovsky
f6b540ebfe Add initial support for virtual time 2023-05-22 15:00:56 +03:00
Arthur Petukhovsky
83f87af02b Remove sync debug 2023-03-10 00:10:09 +02:00
Arthur Petukhovsky
79823c38cd It looks deterministic now 2023-03-10 00:03:35 +02:00
Arthur Petukhovsky
072fb3d7e9 WIP 2023-03-09 14:59:03 +02:00
Arthur Petukhovsky
f2fb9f6be9 WIP 2023-03-09 14:51:29 +02:00
Arthur Petukhovsky
dd4c8fb568 WIP 2023-03-09 00:51:14 +02:00
Arthur Petukhovsky
9116c01614 WIP 2023-03-08 18:45:13 +02:00
Arthur Petukhovsky
17cd96e022 WIP 2023-03-03 20:33:55 +00:00
232 changed files with 10237 additions and 8163 deletions

View File

@@ -14,3 +14,6 @@ opt-level = 1
[alias]
build_testing = ["build", "--features", "testing"]
[build]
rustflags = ["-C", "default-linker-libraries"]

View File

@@ -91,15 +91,6 @@
tags:
- pageserver
# used in `pageserver.service` template
- name: learn current availability_zone
shell:
cmd: "curl -s http://169.254.169.254/latest/meta-data/placement/availability-zone"
register: ec2_availability_zone
- set_fact:
ec2_availability_zone={{ ec2_availability_zone.stdout }}
- name: upload systemd service definition
ansible.builtin.template:
src: systemd/pageserver.service
@@ -127,7 +118,7 @@
cmd: |
INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers/$INSTANCE_ID | jq '.version = {{ current_version }}' > /tmp/new_version
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/pageservers
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/pageservers
tags:
- pageserver
@@ -162,15 +153,6 @@
tags:
- safekeeper
# used in `safekeeper.service` template
- name: learn current availability_zone
shell:
cmd: "curl -s http://169.254.169.254/latest/meta-data/placement/availability-zone"
register: ec2_availability_zone
- set_fact:
ec2_availability_zone={{ ec2_availability_zone.stdout }}
# in the future safekeepers should discover pageservers byself
# but currently use first pageserver that was discovered
- name: set first pageserver var for safekeepers
@@ -206,6 +188,6 @@
cmd: |
INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers/$INSTANCE_ID | jq '.version = {{ current_version }}' > /tmp/new_version
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/safekeepers
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/safekeepers
tags:
- safekeeper

View File

@@ -27,8 +27,6 @@ storage:
ansible_host: i-0cd8d316ecbb715be
pageserver-1.eu-central-1.aws.neon.tech:
ansible_host: i-090044ed3d383fef0
pageserver-2.eu-central-1.aws.neon.tech:
ansible_host: i-033584edf3f4b6742
safekeepers:
hosts:

View File

@@ -26,7 +26,7 @@ EOF
if ! curl -sf -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers/${INSTANCE_ID} -o /dev/null; then
# not registered, so register it now
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" {{ console_mgmt_base_url }}/management/api/v2/pageservers -d@/tmp/payload | jq -r '.id')
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers -d@/tmp/payload | jq -r '.id')
# init pageserver
sudo -u pageserver /usr/local/bin/pageserver -c "id=${ID}" -c "pg_distrib_dir='/usr/local'" --init -D /storage/pageserver/data

View File

@@ -25,7 +25,7 @@ EOF
if ! curl -sf -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers/${INSTANCE_ID} -o /dev/null; then
# not registered, so register it now
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" {{ console_mgmt_base_url }}/management/api/v2/safekeepers -d@/tmp/payload | jq -r '.id')
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers -d@/tmp/payload | jq -r '.id')
# init safekeeper
sudo -u safekeeper /usr/local/bin/safekeeper --id ${ID} --init -D /storage/safekeeper/data
fi

View File

@@ -6,7 +6,7 @@ After=network.target auditd.service
Type=simple
User=pageserver
Environment=RUST_BACKTRACE=1 NEON_REPO_DIR=/storage/pageserver LD_LIBRARY_PATH=/usr/local/v14/lib SENTRY_DSN={{ SENTRY_URL_PAGESERVER }} SENTRY_ENVIRONMENT={{ sentry_environment }}
ExecStart=/usr/local/bin/pageserver -c "pg_distrib_dir='/usr/local'" -c "listen_pg_addr='0.0.0.0:6400'" -c "listen_http_addr='0.0.0.0:9898'" -c "broker_endpoint='{{ broker_endpoint }}'" -c "availability_zone='{{ ec2_availability_zone }}'" -D /storage/pageserver/data
ExecStart=/usr/local/bin/pageserver -c "pg_distrib_dir='/usr/local'" -c "listen_pg_addr='0.0.0.0:6400'" -c "listen_http_addr='0.0.0.0:9898'" -c "broker_endpoint='{{ broker_endpoint }}'" -D /storage/pageserver/data
ExecReload=/bin/kill -HUP $MAINPID
KillMode=mixed
KillSignal=SIGINT

View File

@@ -6,7 +6,7 @@ After=network.target auditd.service
Type=simple
User=safekeeper
Environment=RUST_BACKTRACE=1 NEON_REPO_DIR=/storage/safekeeper/data LD_LIBRARY_PATH=/usr/local/v14/lib SENTRY_DSN={{ SENTRY_URL_SAFEKEEPER }} SENTRY_ENVIRONMENT={{ sentry_environment }}
ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}{{ hostname_suffix }}:6500 --listen-http {{ inventory_hostname }}{{ hostname_suffix }}:7676 -D /storage/safekeeper/data --broker-endpoint={{ broker_endpoint }} --remote-storage='{bucket_name="{{bucket_name}}", bucket_region="{{bucket_region}}", prefix_in_bucket="{{ safekeeper_s3_prefix }}"}' --availability-zone={{ ec2_availability_zone }}
ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}{{ hostname_suffix }}:6500 --listen-http {{ inventory_hostname }}{{ hostname_suffix }}:7676 -D /storage/safekeeper/data --broker-endpoint={{ broker_endpoint }} --remote-storage='{bucket_name="{{bucket_name}}", bucket_region="{{bucket_region}}", prefix_in_bucket="{{ safekeeper_s3_prefix }}"}'
ExecReload=/bin/kill -HUP $MAINPID
KillMode=mixed
KillSignal=SIGINT

View File

@@ -1,22 +1,6 @@
# Helm chart values for neon-proxy-scram.
# This is a YAML-formatted file.
deploymentStrategy:
type: RollingUpdate
rollingUpdate:
maxSurge: 100%
maxUnavailable: 50%
# Delay the kill signal by 7 days (7 * 24 * 60 * 60)
# The pod(s) will stay in Terminating, keeps the existing connections
# but doesn't receive new ones
containerLifecycle:
preStop:
exec:
command: ["/bin/sh", "-c", "sleep 604800"]
terminationGracePeriodSeconds: 604800
image:
repository: neondatabase/neon

View File

@@ -74,12 +74,15 @@ jobs:
- name: Install Python deps
run: ./scripts/pysync
- name: Run ruff to ensure code format
run: poetry run ruff .
- name: Run isort to ensure code format
run: poetry run isort --diff --check .
- name: Run black to ensure code format
run: poetry run black --diff --check .
- name: Run flake8 to ensure code format
run: poetry run flake8 .
- name: Run mypy to check types
run: poetry run mypy .
@@ -548,48 +551,6 @@ jobs:
- name: Cleanup ECR folder
run: rm -rf ~/.ecr
neon-image-depot:
# For testing this will run side-by-side for a few merges.
# This action is not really optimized yet, but gets the job done
runs-on: [ self-hosted, gen3, large ]
needs: [ tag ]
container: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned
permissions:
contents: read
id-token: write
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 0
- name: Setup go
uses: actions/setup-go@v3
with:
go-version: '1.19'
- name: Set up Depot CLI
uses: depot/setup-action@v1
- name: Install Crane & ECR helper
run: go install github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cli/docker-credential-ecr-login@69c85dc22db6511932bbf119e1a0cc5c90c69a7f # v0.6.0
- name: Configure ECR login
run: |
mkdir /github/home/.docker/
echo "{\"credsStore\":\"ecr-login\"}" > /github/home/.docker/config.json
- name: Build and push
uses: depot/build-push-action@v1
with:
# if no depot.json file is at the root of your repo, you must specify the project id
project: nrdv0s4kcs
push: true
tags: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/neon:depot-${{needs.tag.outputs.build-tag}}
compute-tools-image:
runs-on: [ self-hosted, gen3, large ]
needs: [ tag ]

View File

@@ -31,4 +31,3 @@ jobs:
head: releases/${{ steps.date.outputs.date }}
base: release
title: Release ${{ steps.date.outputs.date }}
team_reviewers: release

2
.gitignore vendored
View File

@@ -18,3 +18,5 @@ test_output/
*.o
*.so
*.Po
tmp

133
Cargo.lock generated
View File

@@ -679,6 +679,25 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cbindgen"
version = "0.24.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b922faaf31122819ec80c4047cc684c6979a087366c069611e33649bf98e18d"
dependencies = [
"clap 3.2.23",
"heck",
"indexmap",
"log",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn",
"tempfile",
"toml",
]
[[package]]
name = "cc"
version = "1.0.79"
@@ -757,9 +776,12 @@ version = "3.2.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71655c45cb9845d3270c9d6df84ebe72b4dad3c2ba3f7023ad47c144e4e473a5"
dependencies = [
"atty",
"bitflags",
"clap_lex 0.2.4",
"indexmap",
"strsim",
"termcolor",
"textwrap",
]
@@ -851,7 +873,6 @@ dependencies = [
"futures",
"hyper",
"notify",
"num_cpus",
"opentelemetry",
"postgres",
"regex",
@@ -914,7 +935,6 @@ dependencies = [
"once_cell",
"pageserver_api",
"postgres",
"postgres_backend",
"postgres_connection",
"regex",
"reqwest",
@@ -1016,6 +1036,20 @@ version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6548a0ad5d2549e111e1f6a11a6c2e2d00ce6a3dafe22948d67c2b443f775e52"
[[package]]
name = "crossbeam"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
dependencies = [
"cfg-if",
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
"crossbeam-queue",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.6"
@@ -1050,6 +1084,16 @@ dependencies = [
"scopeguard",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.14"
@@ -2456,7 +2500,6 @@ dependencies = [
"postgres",
"postgres-protocol",
"postgres-types",
"postgres_backend",
"postgres_connection",
"postgres_ffi",
"pq_proto",
@@ -2679,28 +2722,6 @@ dependencies = [
"postgres-protocol",
]
[[package]]
name = "postgres_backend"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"bytes",
"futures",
"once_cell",
"pq_proto",
"rustls",
"rustls-pemfile",
"serde",
"thiserror",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls",
"tracing",
"workspace_hack",
]
[[package]]
name = "postgres_connection"
version = "0.1.0"
@@ -2748,7 +2769,7 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
name = "pq_proto"
version = "0.1.0"
dependencies = [
"byteorder",
"anyhow",
"bytes",
"pin-project-lite",
"postgres-protocol",
@@ -2923,7 +2944,6 @@ dependencies = [
"opentelemetry",
"parking_lot",
"pin-project-lite",
"postgres_backend",
"pq_proto",
"prometheus",
"rand",
@@ -3303,6 +3323,15 @@ dependencies = [
"base64 0.21.0",
]
[[package]]
name = "rustls-split"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78802c9612b4689d207acff746f38132ca1b12dadb55d471aa5f10fd580f47d3"
dependencies = [
"rustls",
]
[[package]]
name = "rustversion"
version = "1.0.11"
@@ -3328,22 +3357,25 @@ dependencies = [
"clap 4.1.4",
"const_format",
"crc32c",
"crossbeam",
"fs2",
"git-version",
"hex",
"humantime",
"hyper",
"metrics",
"nix",
"once_cell",
"parking_lot",
"postgres",
"postgres-protocol",
"postgres_backend",
"postgres_ffi",
"pq_proto",
"rand",
"regex",
"remote_storage",
"safekeeper_api",
"scopeguard",
"serde",
"serde_json",
"serde_with",
@@ -4524,6 +4556,7 @@ dependencies = [
"bytes",
"criterion",
"futures",
"git-version",
"heapless",
"hex",
"hex-literal",
@@ -4532,9 +4565,12 @@ dependencies = [
"metrics",
"nix",
"once_cell",
"pin-project-lite",
"pq_proto",
"rand",
"routerify",
"rustls",
"rustls-pemfile",
"rustls-split",
"sentry",
"serde",
"serde_json",
@@ -4545,6 +4581,7 @@ dependencies = [
"tempfile",
"thiserror",
"tokio",
"tokio-rustls",
"tracing",
"tracing-subscriber",
"url",
@@ -4600,6 +4637,38 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "walproposer"
version = "0.1.0"
dependencies = [
"anyhow",
"atty",
"bindgen",
"byteorder",
"bytes",
"cbindgen",
"crc32c",
"env_logger",
"hex",
"hyper",
"libc",
"log",
"memoffset 0.8.0",
"once_cell",
"postgres",
"postgres_ffi",
"rand",
"regex",
"safekeeper",
"scopeguard",
"serde",
"thiserror",
"tracing",
"tracing-subscriber",
"utils",
"workspace_hack",
]
[[package]]
name = "want"
version = "0.3.0"
@@ -4848,19 +4917,14 @@ name = "workspace_hack"
version = "0.1.0"
dependencies = [
"anyhow",
"byteorder",
"bytes",
"chrono",
"clap 4.1.4",
"crossbeam-utils",
"digest",
"either",
"fail",
"futures",
"futures-channel",
"futures-core",
"futures-executor",
"futures-sink",
"futures-util",
"hashbrown 0.12.3",
"indexmap",
@@ -4885,7 +4949,6 @@ dependencies = [
"socket2",
"syn",
"tokio",
"tokio-rustls",
"tokio-util",
"tonic",
"tower",

View File

@@ -64,7 +64,6 @@ md5 = "0.7.0"
memoffset = "0.8"
nix = "0.26"
notify = "5.0.0"
num_cpus = "1.15"
num-traits = "0.2.15"
once_cell = "1.13"
opentelemetry = "0.18.0"
@@ -134,16 +133,17 @@ heapless = { default-features=false, features=[], git = "https://github.com/japa
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" }
postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" }
postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" }
pq_proto = { version = "0.1", path = "./libs/pq_proto/" }
remote_storage = { version = "0.1", path = "./libs/remote_storage/" }
safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" }
safekeeper = { path = "./safekeeper/" }
storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy.
tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" }
tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
utils = { version = "0.1", path = "./libs/utils/" }
walproposer = { version = "0.1", path = "./libs/walproposer/" }
## Common library dependency
workspace_hack = { version = "0.1", path = "./workspace_hack/" }

View File

@@ -39,7 +39,7 @@ ARG CACHEPOT_BUCKET=neon-github-dev
COPY --from=pg-build /home/nonroot/pg_install/v14/include/postgresql/server pg_install/v14/include/postgresql/server
COPY --from=pg-build /home/nonroot/pg_install/v15/include/postgresql/server pg_install/v15/include/postgresql/server
COPY --chown=nonroot . .
COPY . .
# Show build caching stats to check if it was used in the end.
# Has to be the part of the same RUN since cachepot daemon is killed in the end of this RUN, losing the compilation stats.

View File

@@ -255,51 +255,6 @@ RUN wget https://github.com/theory/pgtap/archive/refs/tags/v1.2.0.tar.gz -O pgta
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgtap.control
#########################################################################################
#
# Layer "prefix-pg-build"
# compile Prefix extension
#
#########################################################################################
FROM build-deps AS prefix-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.9.tar.gz -O prefix.tar.gz && \
mkdir prefix-src && cd prefix-src && tar xvzf ../prefix.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/prefix.control
#########################################################################################
#
# Layer "hll-pg-build"
# compile hll extension
#
#########################################################################################
FROM build-deps AS hll-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.17.tar.gz -O hll.tar.gz && \
mkdir hll-src && cd hll-src && tar xvzf ../hll.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/hll.control
#########################################################################################
#
# Layer "plpgsql-check-pg-build"
# compile plpgsql_check extension
#
#########################################################################################
FROM build-deps AS plpgsql-check-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.3.2.tar.gz -O plpgsql_check.tar.gz && \
mkdir plpgsql_check-src && cd plpgsql_check-src && tar xvzf ../plpgsql_check.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/plpgsql_check.control
#########################################################################################
#
# Layer "rust extensions"
@@ -323,7 +278,7 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
chmod +x rustup-init && \
./rustup-init -y --no-modify-path --profile minimal --default-toolchain stable && \
rm rustup-init && \
cargo install --locked --version 0.7.3 cargo-pgx && \
cargo install --git https://github.com/vadim2404/pgx --branch neon_abi_v0.6.1 --locked cargo-pgx && \
/bin/bash -c 'cargo pgx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
USER root
@@ -337,11 +292,11 @@ USER root
FROM rust-extensions-build AS pg-jsonschema-pg-build
# there is no release tag yet, but we need it due to the superuser fix in the control file
RUN wget https://github.com/supabase/pg_jsonschema/archive/caeab60d70b2fd3ae421ec66466a3abbb37b7ee6.tar.gz -O pg_jsonschema.tar.gz && \
mkdir pg_jsonschema-src && cd pg_jsonschema-src && tar xvzf ../pg_jsonschema.tar.gz --strip-components=1 -C . && \
sed -i 's/pgx = "0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
RUN git clone --depth=1 --single-branch --branch neon_abi_v0.1.4 https://github.com/vadim2404/pg_jsonschema/ && \
cd pg_jsonschema && \
cargo pgx install --release && \
# it's needed to enable extension because it uses untrusted C language
sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_jsonschema.control && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_jsonschema.control
#########################################################################################
@@ -353,32 +308,13 @@ RUN wget https://github.com/supabase/pg_jsonschema/archive/caeab60d70b2fd3ae421e
FROM rust-extensions-build AS pg-graphql-pg-build
# Currently pgx version bump to >= 0.7.2 causes "call to unsafe function" compliation errors in
# pgx-contrib-spiext. There is a branch that removes that dependency, so use it. It is on the
# same 1.1 version we've used before.
RUN git clone -b remove-pgx-contrib-spiext --single-branch https://github.com/yrashk/pg_graphql && \
cd pg_graphql && \
sed -i 's/pgx = "~0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
sed -i 's/pgx-tests = "~0.7.1"/pgx-tests = "0.7.3"/g' Cargo.toml && \
RUN git clone --depth=1 --single-branch --branch neon_abi_v1.1.0 https://github.com/vadim2404/pg_graphql && \
cd pg_graphql && \
cargo pgx install --release && \
# it's needed to enable extension because it uses untrusted C language
sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_graphql.control && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_graphql.control
#########################################################################################
#
# Layer "pg-tiktoken-build"
# Compile "pg_tiktoken" extension
#
#########################################################################################
FROM rust-extensions-build AS pg-tiktoken-pg-build
RUN git clone --depth=1 --single-branch https://github.com/kelvich/pg_tiktoken && \
cd pg_tiktoken && \
cargo pgx install --release && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_tiktoken.control
#########################################################################################
#
# Layer "neon-pg-ext-build"
@@ -396,23 +332,15 @@ COPY --from=vector-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pgjwt-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-jsonschema-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-graphql-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-tiktoken-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=hypopg-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-hashids-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=rum-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pgtap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=prefix-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=hll-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=plpgsql-check-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) \
PG_CONFIG=/usr/local/pgsql/bin/pg_config \
-C pgxn/neon \
-s install && \
make -j $(getconf _NPROCESSORS_ONLN) \
PG_CONFIG=/usr/local/pgsql/bin/pg_config \
-C pgxn/neon_utils \
-s install
#########################################################################################
@@ -467,7 +395,7 @@ COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-deb
# Install:
# libreadline8 for psql
# libicu67, locales for collations (including ICU and plpgsql_check)
# libicu67, locales for collations (including ICU)
# libossp-uuid16 for extension ossp-uuid
# libgeos, libgdal, libsfcgal1, libproj and libprotobuf-c1 for PostGIS
# libxml2, libxslt1.1 for xml2

View File

@@ -1,70 +1,25 @@
# Note: this file *mostly* just builds on Dockerfile.compute-node
ARG SRC_IMAGE
ARG VM_INFORMANT_VERSION=v0.1.14
# on libcgroup update, make sure to check bootstrap.sh for changes
ARG LIBCGROUP_VERSION=v2.0.3
ARG VM_INFORMANT_VERSION=v0.1.6
# Pull VM informant, to copy from later
# Pull VM informant and set up inittab
FROM neondatabase/vm-informant:$VM_INFORMANT_VERSION as informant
# Build cgroup-tools
#
# At time of writing (2023-03-14), debian bullseye has a version of cgroup-tools (technically
# libcgroup) that doesn't support cgroup v2 (version 0.41-11). Unfortunately, the vm-informant
# requires cgroup v2, so we'll build cgroup-tools ourselves.
FROM debian:bullseye-slim as libcgroup-builder
ARG LIBCGROUP_VERSION
RUN set -exu \
&& apt update \
&& apt install --no-install-recommends -y \
git \
ca-certificates \
automake \
cmake \
make \
gcc \
byacc \
flex \
libtool \
libpam0g-dev \
&& git clone --depth 1 -b $LIBCGROUP_VERSION https://github.com/libcgroup/libcgroup \
&& INSTALL_DIR="/libcgroup-install" \
&& mkdir -p "$INSTALL_DIR/bin" "$INSTALL_DIR/include" \
&& cd libcgroup \
# extracted from bootstrap.sh, with modified flags:
&& (test -d m4 || mkdir m4) \
&& autoreconf -fi \
&& rm -rf autom4te.cache \
&& CFLAGS="-O3" ./configure --prefix="$INSTALL_DIR" --sysconfdir=/etc --localstatedir=/var --enable-opaque-hierarchy="name=systemd" \
# actually build the thing...
&& make install
# Combine, starting from non-VM compute node image.
FROM $SRC_IMAGE as base
# Temporarily set user back to root so we can run adduser, set inittab
USER root
RUN adduser vm-informant --disabled-password --no-create-home
RUN set -e \
&& rm -f /etc/inittab \
&& touch /etc/inittab
RUN set -e \
&& echo "::sysinit:cgconfigparser -l /etc/cgconfig.conf -s 1664" >> /etc/inittab \
&& CONNSTR="dbname=neondb user=cloud_admin sslmode=disable" \
&& ARGS="--auto-restart --cgroup=neon-postgres --pgconnstr=\"$CONNSTR\"" \
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant $ARGS'" >> /etc/inittab
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant --auto-restart'" >> /etc/inittab
# Combine, starting from non-VM compute node image.
FROM $SRC_IMAGE as base
# Temporarily set user back to root so we can run adduser
USER root
RUN adduser vm-informant --disabled-password --no-create-home
USER postgres
ADD vm-cgconfig.conf /etc/cgconfig.conf
COPY --from=informant /etc/inittab /etc/inittab
COPY --from=informant /usr/bin/vm-informant /usr/local/bin/vm-informant
COPY --from=libcgroup-builder /libcgroup-install/bin/* /usr/bin/
COPY --from=libcgroup-builder /libcgroup-install/lib/* /usr/lib/
COPY --from=libcgroup-builder /libcgroup-install/sbin/* /usr/sbin/
ENTRYPOINT ["/usr/sbin/cgexec", "-g", "*:neon-postgres", "/usr/local/bin/compute_ctl"]

View File

@@ -39,6 +39,8 @@ endif
# been no changes to the files. Changing the mtime triggers an
# unnecessary rebuild of 'postgres_ffi'.
PG_CONFIGURE_OPTS += INSTALL='$(ROOT_PROJECT_DIR)/scripts/ninstall.sh -C'
PG_CONFIGURE_OPTS += CC=clang
PG_CONFIGURE_OPTS += CCX=clang++
# Choose whether we should be silent or verbose
CARGO_BUILD_FLAGS += --$(if $(filter s,$(MAKEFLAGS)),quiet,verbose)
@@ -133,11 +135,12 @@ neon-pg-ext-%: postgres-%
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile install
+@echo "Compiling neon_utils $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-utils-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile install
.PHONY:
neon-pg-ext-walproposer:
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v15/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-v15 \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile install
.PHONY: neon-pg-ext-clean-%
neon-pg-ext-clean-%:
@@ -150,9 +153,6 @@ neon-pg-ext-clean-%:
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile clean
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile clean
.PHONY: neon-pg-ext
neon-pg-ext: \

View File

@@ -46,14 +46,11 @@ postgresql-libs cmake postgresql protobuf
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```
#### Installing dependencies on macOS (12.3.1)
#### Installing dependencies on OSX (12.3.1)
1. Install XCode and dependencies
```
xcode-select --install
brew install protobuf openssl flex bison
# add openssl to PATH, required for ed25519 keys generation in neon_local
echo 'export PATH="$(brew --prefix openssl)/bin:$PATH"' >> ~/.zshrc
```
2. [Install Rust](https://www.rust-lang.org/tools/install)

View File

@@ -11,7 +11,6 @@ clap.workspace = true
futures.workspace = true
hyper = { workspace = true, features = ["full"] }
notify.workspace = true
num_cpus.workspace = true
opentelemetry.workspace = true
postgres.workspace = true
regex.workspace = true

View File

@@ -133,7 +133,6 @@ fn main() -> Result<()> {
.settings
.find("neon.pageserver_connstring")
.expect("pageserver connstr should be provided");
let storage_auth_token = spec.storage_auth_token.clone();
let tenant = spec
.cluster
.settings
@@ -154,7 +153,6 @@ fn main() -> Result<()> {
tenant,
timeline,
pageserver_connstr,
storage_auth_token,
metrics: ComputeMetrics::default(),
state: RwLock::new(ComputeState::new()),
};

View File

@@ -18,7 +18,6 @@ use std::fs;
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::process::{Command, Stdio};
use std::str::FromStr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
@@ -26,7 +25,6 @@ use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use postgres::{Client, NoTls};
use serde::{Serialize, Serializer};
use tokio_postgres;
use tracing::{info, instrument, warn};
use crate::checker::create_writability_check_data;
@@ -45,7 +43,6 @@ pub struct ComputeNode {
pub tenant: String,
pub timeline: String,
pub pageserver_connstr: String,
pub storage_auth_token: Option<String>,
pub metrics: ComputeMetrics,
/// Volatile part of the `ComputeNode` so should be used under `RwLock`
/// to allow HTTP API server to serve status requests, while configuration
@@ -128,18 +125,7 @@ impl ComputeNode {
fn get_basebackup(&self, lsn: &str) -> Result<()> {
let start_time = Utc::now();
let mut config = postgres::Config::from_str(&self.pageserver_connstr)?;
// Use the storage auth token from the config file, if given.
// Note: this overrides any password set in the connection string.
if let Some(storage_auth_token) = &self.storage_auth_token {
info!("Got storage auth token from spec file");
config.password(storage_auth_token);
} else {
info!("Storage auth token not set");
}
let mut client = config.connect(NoTls)?;
let mut client = Client::connect(&self.pageserver_connstr, NoTls)?;
let basebackup_cmd = match lsn {
"0/0" => format!("basebackup {} {}", &self.tenant, &self.timeline), // First start of the compute
_ => format!("basebackup {} {} {}", &self.tenant, &self.timeline, lsn),
@@ -176,11 +162,6 @@ impl ComputeNode {
let sync_handle = Command::new(&self.pgbin)
.args(["--sync-safekeepers"])
.env("PGDATA", &self.pgdata) // we cannot use -D in this mode
.envs(if let Some(storage_auth_token) = &self.storage_auth_token {
vec![("NEON_AUTH_TOKEN", storage_auth_token)]
} else {
vec![]
})
.stdout(Stdio::piped())
.spawn()
.expect("postgres --sync-safekeepers failed to start");
@@ -258,11 +239,6 @@ impl ComputeNode {
// Run postgres as a child process.
let mut pg = Command::new(&self.pgbin)
.args(["-D", &self.pgdata])
.envs(if let Some(storage_auth_token) = &self.storage_auth_token {
vec![("NEON_AUTH_TOKEN", storage_auth_token)]
} else {
vec![]
})
.spawn()
.expect("cannot start postgres process");
@@ -308,7 +284,6 @@ impl ComputeNode {
handle_role_deletions(self, &mut client)?;
handle_grants(self, &mut client)?;
create_writability_check_data(&mut client)?;
handle_extensions(&self.spec, &mut client)?;
// 'Close' connection
drop(client);
@@ -425,43 +400,4 @@ impl ComputeNode {
Ok(())
}
/// Select `pg_stat_statements` data and return it as a stringified JSON
pub async fn collect_insights(&self) -> String {
let mut result_rows: Vec<String> = Vec::new();
let connect_result = tokio_postgres::connect(self.connstr.as_str(), NoTls).await;
let (client, connection) = connect_result.unwrap();
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let result = client
.simple_query(
"SELECT
row_to_json(pg_stat_statements)
FROM
pg_stat_statements
WHERE
userid != 'cloud_admin'::regrole::oid
ORDER BY
(mean_exec_time + mean_plan_time) DESC
LIMIT 100",
)
.await;
if let Ok(raw_rows) = result {
for message in raw_rows.iter() {
if let postgres::SimpleQueryMessage::Row(row) = message {
if let Some(json) = row.get(0) {
result_rows.push(json.to_string());
}
}
}
format!("{{\"pg_stat_statements\": [{}]}}", result_rows.join(","))
} else {
"{{\"pg_stat_statements\": []}}".to_string()
}
}
}

View File

@@ -7,7 +7,6 @@ use crate::compute::ComputeNode;
use anyhow::Result;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use num_cpus;
use serde_json;
use tracing::{error, info};
use tracing_utils::http::OtelName;
@@ -34,13 +33,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
Response::new(Body::from(serde_json::to_string(&compute.metrics).unwrap()))
}
// Collect Postgres current usage insights
(&Method::GET, "/insights") => {
info!("serving /insights GET request");
let insights = compute.collect_insights().await;
Response::new(Body::from(insights))
}
(&Method::POST, "/check_writability") => {
info!("serving /check_writability POST request");
let res = crate::checker::check_writability(compute).await;
@@ -50,17 +42,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
(&Method::GET, "/info") => {
let num_cpus = num_cpus::get_physical();
info!("serving /info GET request. num_cpus: {}", num_cpus);
Response::new(Body::from(
serde_json::json!({
"num_cpus": num_cpus,
})
.to_string(),
))
}
// Return the `404 Not Found` for any other routes.
_ => {
let mut not_found = Response::new(Body::from("404 Not Found"));

View File

@@ -10,12 +10,12 @@ paths:
/status:
get:
tags:
- Info
- "info"
summary: Get compute node internal status
description: ""
operationId: getComputeStatus
responses:
200:
"200":
description: ComputeState
content:
application/json:
@@ -25,58 +25,27 @@ paths:
/metrics.json:
get:
tags:
- Info
- "info"
summary: Get compute node startup metrics in JSON format
description: ""
operationId: getComputeMetricsJSON
responses:
200:
"200":
description: ComputeMetrics
content:
application/json:
schema:
$ref: "#/components/schemas/ComputeMetrics"
/insights:
get:
tags:
- Info
summary: Get current compute insights in JSON format
description: |
Note, that this doesn't include any historical data
operationId: getComputeInsights
responses:
200:
description: Compute insights
content:
application/json:
schema:
$ref: "#/components/schemas/ComputeInsights"
/info:
get:
tags:
- "info"
summary: Get info about the compute Pod/VM
description: ""
operationId: getInfo
responses:
"200":
description: Info
content:
application/json:
schema:
$ref: "#/components/schemas/Info"
/check_writability:
post:
tags:
- Check
- "check"
summary: Check that we can write new data on this compute
description: ""
operationId: checkComputeWritability
responses:
200:
"200":
description: Check result
content:
text/plain:
@@ -111,15 +80,6 @@ components:
total_startup_ms:
type: integer
Info:
type: object
description: Information about VM/Pod
required:
- num_cpus
properties:
num_cpus:
type: integer
ComputeState:
type: object
required:
@@ -136,15 +96,6 @@ components:
type: string
description: Text of the error during compute startup, if any
ComputeInsights:
type: object
properties:
pg_stat_statements:
description: Contains raw output from pg_stat_statements in JSON format
type: array
items:
type: object
ComputeStatus:
type: string
enum:

View File

@@ -47,23 +47,12 @@ pub struct GenericOption {
/// declare a `trait` on it.
pub type GenericOptions = Option<Vec<GenericOption>>;
/// Escape a string for including it in a SQL literal
fn escape_literal(s: &str) -> String {
s.replace('\'', "''").replace('\\', "\\\\")
}
/// Escape a string so that it can be used in postgresql.conf.
/// Same as escape_literal, currently.
fn escape_conf_value(s: &str) -> String {
s.replace('\'', "''").replace('\\', "\\\\")
}
impl GenericOption {
/// Represent `GenericOption` as SQL statement parameter.
pub fn to_pg_option(&self) -> String {
if let Some(val) = &self.value {
match self.vartype.as_ref() {
"string" => format!("{} '{}'", self.name, escape_literal(val)),
"string" => format!("{} '{}'", self.name, val),
_ => format!("{} {}", self.name, val),
}
} else {
@@ -74,8 +63,6 @@ impl GenericOption {
/// Represent `GenericOption` as configuration option.
pub fn to_pg_setting(&self) -> String {
if let Some(val) = &self.value {
// TODO: check in the console DB that we don't have these settings
// set for any non-deleted project and drop this override.
let name = match self.name.as_str() {
"safekeepers" => "neon.safekeepers",
"wal_acceptor_reconnect" => "neon.safekeeper_reconnect_timeout",
@@ -84,7 +71,7 @@ impl GenericOption {
};
match self.vartype.as_ref() {
"string" => format!("{} = '{}'", name, escape_conf_value(val)),
"string" => format!("{} = '{}'", name, val),
_ => format!("{} = {}", name, val),
}
} else {
@@ -120,7 +107,6 @@ impl PgOptionsSerialize for GenericOptions {
.map(|op| op.to_pg_setting())
.collect::<Vec<String>>()
.join("\n")
+ "\n" // newline after last setting
} else {
"".to_string()
}

View File

@@ -24,8 +24,6 @@ pub struct ComputeSpec {
pub cluster: Cluster,
pub delta_operations: Option<Vec<DeltaOp>>,
pub storage_auth_token: Option<String>,
pub startup_tracing_context: Option<HashMap<String, String>>,
}
@@ -517,18 +515,3 @@ pub fn handle_grants(node: &ComputeNode, client: &mut Client) -> Result<()> {
Ok(())
}
/// Create required system extensions
#[instrument(skip_all)]
pub fn handle_extensions(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
if libs.contains("pg_stat_statements") {
// Create extension only if this compute really needs it
let query = "CREATE EXTENSION IF NOT EXISTS pg_stat_statements";
info!("creating system extensions with query: {}", query);
client.simple_query(query)?;
}
}
Ok(())
}

View File

@@ -178,11 +178,6 @@
"name": "neon.pageserver_connstring",
"value": "host=127.0.0.1 port=6400",
"vartype": "string"
},
{
"name": "test.escaping",
"value": "here's a backslash \\ and a quote ' and a double-quote \" hooray",
"vartype": "string"
}
]
},

View File

@@ -28,30 +28,7 @@ mod pg_helpers_tests {
assert_eq!(
spec.cluster.settings.as_pg_settings(),
r#"fsync = off
wal_level = replica
hot_standby = on
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
wal_log_hints = on
log_connections = on
shared_buffers = 32768
port = 55432
max_connections = 100
max_wal_senders = 10
listen_addresses = '0.0.0.0'
wal_sender_timeout = 0
password_encryption = md5
maintenance_work_mem = 65536
max_parallel_workers = 8
max_worker_processes = 8
neon.tenant_id = 'b0554b632bd4d547a63b86c3630317e8'
max_replication_slots = 10
neon.timeline_id = '2414a61ffc94e428f14b5758fe308e13'
shared_preload_libraries = 'neon'
synchronous_standby_names = 'walproposer'
neon.pageserver_connstring = 'host=127.0.0.1 port=6400'
test.escaping = 'here''s a backslash \\ and a quote '' and a double-quote " hooray'
"#
"fsync = off\nwal_level = replica\nhot_standby = on\nneon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'\nwal_log_hints = on\nlog_connections = on\nshared_buffers = 32768\nport = 55432\nmax_connections = 100\nmax_wal_senders = 10\nlisten_addresses = '0.0.0.0'\nwal_sender_timeout = 0\npassword_encryption = md5\nmaintenance_work_mem = 65536\nmax_parallel_workers = 8\nmax_worker_processes = 8\nneon.tenant_id = 'b0554b632bd4d547a63b86c3630317e8'\nmax_replication_slots = 10\nneon.timeline_id = '2414a61ffc94e428f14b5758fe308e13'\nshared_preload_libraries = 'neon'\nsynchronous_standby_names = 'walproposer'\nneon.pageserver_connstring = 'host=127.0.0.1 port=6400'"
);
}

View File

@@ -24,7 +24,6 @@ url.workspace = true
# Note: Do not directly depend on pageserver or safekeeper; use pageserver_api or safekeeper_api
# instead, so that recompile times are better.
pageserver_api.workspace = true
postgres_backend.workspace = true
safekeeper_api.workspace = true
postgres_connection.workspace = true
storage_broker.workspace = true

View File

@@ -2,8 +2,7 @@
[pageserver]
listen_pg_addr = '127.0.0.1:64000'
listen_http_addr = '127.0.0.1:9898'
pg_auth_type = 'Trust'
http_auth_type = 'Trust'
auth_type = 'Trust'
[[safekeepers]]
id = 1

View File

@@ -3,8 +3,7 @@
[pageserver]
listen_pg_addr = '127.0.0.1:64000'
listen_http_addr = '127.0.0.1:9898'
pg_auth_type = 'Trust'
http_auth_type = 'Trust'
auth_type = 'Trust'
[[safekeepers]]
id = 1

View File

@@ -17,7 +17,6 @@ use pageserver_api::{
DEFAULT_HTTP_LISTEN_ADDR as DEFAULT_PAGESERVER_HTTP_ADDR,
DEFAULT_PG_LISTEN_ADDR as DEFAULT_PAGESERVER_PG_ADDR,
};
use postgres_backend::AuthType;
use safekeeper_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
@@ -31,6 +30,7 @@ use utils::{
auth::{Claims, Scope},
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
project_git_version,
};
@@ -53,15 +53,14 @@ listen_addr = '{DEFAULT_BROKER_ADDR}'
id = {DEFAULT_PAGESERVER_ID}
listen_pg_addr = '{DEFAULT_PAGESERVER_PG_ADDR}'
listen_http_addr = '{DEFAULT_PAGESERVER_HTTP_ADDR}'
pg_auth_type = '{trust_auth}'
http_auth_type = '{trust_auth}'
auth_type = '{pageserver_auth_type}'
[[safekeepers]]
id = {DEFAULT_SAFEKEEPER_ID}
pg_port = {DEFAULT_SAFEKEEPER_PG_PORT}
http_port = {DEFAULT_SAFEKEEPER_HTTP_PORT}
"#,
trust_auth = AuthType::Trust,
pageserver_auth_type = AuthType::Trust,
)
}
@@ -628,7 +627,7 @@ fn handle_pg(pg_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
let node = cplane.nodes.get(&(tenant_id, node_name.to_string()));
let auth_token = if matches!(env.pageserver.pg_auth_type, AuthType::NeonJWT) {
let auth_token = if matches!(env.pageserver.auth_type, AuthType::NeonJWT) {
let claims = Claims::new(Some(tenant_id), Scope::Tenant);
Some(env.generate_auth_token(&claims)?)

View File

@@ -14,6 +14,7 @@ use anyhow::{Context, Result};
use utils::{
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
};
use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION};
@@ -96,7 +97,7 @@ impl ComputeControlPlane {
});
node.create_pgdata()?;
node.setup_pg_conf()?;
node.setup_pg_conf(self.env.pageserver.auth_type)?;
self.nodes
.insert((tenant_id, node.name.clone()), Arc::clone(&node));
@@ -277,7 +278,7 @@ impl PostgresNode {
// Write postgresql.conf with default configuration
// and PG_VERSION file to the data directory of a new node.
fn setup_pg_conf(&self) -> Result<()> {
fn setup_pg_conf(&self, auth_type: AuthType) -> Result<()> {
let mut conf = PostgresConf::new();
conf.append("max_wal_senders", "10");
conf.append("wal_log_hints", "off");
@@ -301,12 +302,29 @@ impl PostgresNode {
let config = &self.pageserver.pg_connection_config;
let (host, port) = (config.host(), config.port());
// NOTE: avoid spaces in connection string, because it is less error prone if we forward it somewhere.
format!("postgresql://no_user@{host}:{port}")
// Set up authentication
//
// $NEON_AUTH_TOKEN will be replaced with value from environment
// variable during compute pg startup. It is done this way because
// otherwise user will be able to retrieve the value using SHOW
// command or pg_settings
let password = if let AuthType::NeonJWT = auth_type {
"$NEON_AUTH_TOKEN"
} else {
""
};
// NOTE avoiding spaces in connection string, because it is less error prone if we forward it somewhere.
// Also note that not all parameters are supported here. Because in compute we substitute $NEON_AUTH_TOKEN
// We parse this string and build it back with token from env var, and for simplicity rebuild
// uses only needed variables namely host, port, user, password.
format!("postgresql://no_user:{password}@{host}:{port}")
};
conf.append("shared_preload_libraries", "neon");
conf.append_line("");
conf.append("neon.pageserver_connstring", &pageserver_connstr);
if let AuthType::NeonJWT = auth_type {
conf.append("neon.safekeeper_token_env", "$NEON_AUTH_TOKEN");
}
conf.append("neon.tenant_id", &self.tenant_id.to_string());
conf.append("neon.timeline_id", &self.timeline_id.to_string());
if let Some(lsn) = self.lsn {
@@ -429,8 +447,6 @@ impl PostgresNode {
"DYLD_LIBRARY_PATH",
self.env.pg_lib_dir(self.pg_version)?.to_str().unwrap(),
);
// Pass authentication token used for the connections to pageserver and safekeepers
if let Some(token) = auth_token {
cmd.env("NEON_AUTH_TOKEN", token);
}

View File

@@ -5,7 +5,6 @@
use anyhow::{bail, ensure, Context};
use postgres_backend::AuthType;
use reqwest::Url;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
@@ -18,8 +17,9 @@ use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use utils::{
auth::{encode_from_key_file, Claims},
auth::{encode_from_key_file, Claims, Scope},
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
postgres_backend::AuthType,
};
use crate::safekeeper::SafekeeperNode;
@@ -110,14 +110,15 @@ impl NeonBroker {
pub struct PageServerConf {
// node id
pub id: NodeId,
// Pageserver connection settings
pub listen_pg_addr: String,
pub listen_http_addr: String,
// auth type used for the PG and HTTP ports
pub pg_auth_type: AuthType,
pub http_auth_type: AuthType,
// used to determine which auth type is used
pub auth_type: AuthType,
// jwt auth token used for communication with pageserver
pub auth_token: String,
}
impl Default for PageServerConf {
@@ -126,8 +127,8 @@ impl Default for PageServerConf {
id: NodeId(0),
listen_pg_addr: String::new(),
listen_http_addr: String::new(),
pg_auth_type: AuthType::Trust,
http_auth_type: AuthType::Trust,
auth_type: AuthType::Trust,
auth_token: String::new(),
}
}
}
@@ -400,33 +401,48 @@ impl LocalEnv {
fs::create_dir(base_path)?;
// Generate keypair for JWT.
//
// The keypair is only needed if authentication is enabled in any of the
// components. For convenience, we generate the keypair even if authentication
// is not enabled, so that you can easily enable it after the initialization
// step. However, if the key generation fails, we treat it as non-fatal if
// authentication was not enabled.
// generate keys for jwt
// openssl genrsa -out private_key.pem 2048
let private_key_path;
if self.private_key_path == PathBuf::new() {
match generate_auth_keys(
base_path.join("auth_private_key.pem").as_path(),
base_path.join("auth_public_key.pem").as_path(),
) {
Ok(()) => {
self.private_key_path = PathBuf::from("auth_private_key.pem");
}
Err(e) => {
if !self.auth_keys_needed() {
eprintln!("Could not generate keypair for JWT authentication: {e}");
eprintln!("Continuing anyway because authentication was not enabled");
self.private_key_path = PathBuf::from("auth_private_key.pem");
} else {
return Err(e);
}
}
private_key_path = base_path.join("auth_private_key.pem");
let keygen_output = Command::new("openssl")
.arg("genrsa")
.args(["-out", private_key_path.to_str().unwrap()])
.arg("2048")
.stdout(Stdio::null())
.output()
.context("failed to generate auth private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
self.private_key_path = PathBuf::from("auth_private_key.pem");
let public_key_path = base_path.join("auth_public_key.pem");
// openssl rsa -in private_key.pem -pubout -outform PEM -out public_key.pem
let keygen_output = Command::new("openssl")
.arg("rsa")
.args(["-in", private_key_path.to_str().unwrap()])
.arg("-pubout")
.args(["-outform", "PEM"])
.args(["-out", public_key_path.to_str().unwrap()])
.stdout(Stdio::null())
.output()
.context("failed to generate auth private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
}
self.pageserver.auth_token =
self.generate_auth_token(&Claims::new(None, Scope::PageServerApi))?;
fs::create_dir_all(self.pg_data_dirs_path())?;
for safekeeper in &self.safekeepers {
@@ -435,12 +451,6 @@ impl LocalEnv {
self.persist_config(base_path)
}
fn auth_keys_needed(&self) -> bool {
self.pageserver.pg_auth_type == AuthType::NeonJWT
|| self.pageserver.http_auth_type == AuthType::NeonJWT
|| self.safekeepers.iter().any(|sk| sk.auth_enabled)
}
}
fn base_path() -> PathBuf {
@@ -450,43 +460,6 @@ fn base_path() -> PathBuf {
}
}
/// Generate a public/private key pair for JWT authentication
fn generate_auth_keys(private_key_path: &Path, public_key_path: &Path) -> anyhow::Result<()> {
// Generate the key pair
//
// openssl genpkey -algorithm ed25519 -out auth_private_key.pem
let keygen_output = Command::new("openssl")
.arg("genpkey")
.args(["-algorithm", "ed25519"])
.args(["-out", private_key_path.to_str().unwrap()])
.stdout(Stdio::null())
.output()
.context("failed to generate auth private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
// Extract the public key from the private key file
//
// openssl pkey -in auth_private_key.pem -pubout -out auth_public_key.pem
let keygen_output = Command::new("openssl")
.arg("pkey")
.args(["-in", private_key_path.to_str().unwrap()])
.arg("-pubout")
.args(["-out", public_key_path.to_str().unwrap()])
.output()
.context("failed to extract public key from private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -11,7 +11,6 @@ use anyhow::{bail, Context};
use pageserver_api::models::{
TenantConfigRequest, TenantCreateRequest, TenantInfo, TimelineCreateRequest, TimelineInfo,
};
use postgres_backend::AuthType;
use postgres_connection::{parse_host_port, PgConnectionConfig};
use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::{IntoUrl, Method};
@@ -21,6 +20,7 @@ use utils::{
http::error::HttpErrorBody,
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
};
use crate::{background_process, local_env::LocalEnv};
@@ -82,8 +82,15 @@ impl PageServerNode {
let (host, port) = parse_host_port(&env.pageserver.listen_pg_addr)
.expect("Unable to parse listen_pg_addr");
let port = port.unwrap_or(5432);
let password = if env.pageserver.auth_type == AuthType::NeonJWT {
Some(env.pageserver.auth_token.clone())
} else {
None
};
Self {
pg_connection_config: PgConnectionConfig::new_host_port(host, port),
pg_connection_config: PgConnectionConfig::new_host_port(host, port)
.set_password(password),
env: env.clone(),
http_client: Client::new(),
http_base_url: format!("http://{}/v1", env.pageserver.listen_http_addr),
@@ -99,32 +106,25 @@ impl PageServerNode {
self.env.pg_distrib_dir_raw().display()
);
let http_auth_type_param =
format!("http_auth_type='{}'", self.env.pageserver.http_auth_type);
let authg_type_param = format!("auth_type='{}'", self.env.pageserver.auth_type);
let listen_http_addr_param = format!(
"listen_http_addr='{}'",
self.env.pageserver.listen_http_addr
);
let pg_auth_type_param = format!("pg_auth_type='{}'", self.env.pageserver.pg_auth_type);
let listen_pg_addr_param =
format!("listen_pg_addr='{}'", self.env.pageserver.listen_pg_addr);
let broker_endpoint_param = format!("broker_endpoint='{}'", self.env.broker.client_url());
let mut overrides = vec![
id,
pg_distrib_dir_param,
http_auth_type_param,
pg_auth_type_param,
authg_type_param,
listen_http_addr_param,
listen_pg_addr_param,
broker_endpoint_param,
];
if self.env.pageserver.http_auth_type != AuthType::Trust
|| self.env.pageserver.pg_auth_type != AuthType::Trust
{
if self.env.pageserver.auth_type != AuthType::Trust {
overrides.push("auth_validation_public_key_path='auth_public_key.pem'".to_owned());
}
overrides
@@ -247,10 +247,7 @@ impl PageServerNode {
}
fn pageserver_env_variables(&self) -> anyhow::Result<Vec<(String, String)>> {
// FIXME: why is this tied to pageserver's auth type? Whether or not the safekeeper
// needs a token, and how to generate that token, seems independent to whether
// the pageserver requires a token in incoming requests.
Ok(if self.env.pageserver.http_auth_type != AuthType::Trust {
Ok(if self.env.pageserver.auth_type != AuthType::Trust {
// Generate a token to connect from the pageserver to a safekeeper
let token = self
.env
@@ -273,30 +270,27 @@ impl PageServerNode {
background_process::stop_process(immediate, "pageserver", &self.pid_file())
}
pub fn page_server_psql_client(&self) -> anyhow::Result<postgres::Client> {
let mut config = self.pg_connection_config.clone();
if self.env.pageserver.pg_auth_type == AuthType::NeonJWT {
let token = self
.env
.generate_auth_token(&Claims::new(None, Scope::PageServerApi))?;
config = config.set_password(Some(token));
}
Ok(config.connect_no_tls()?)
pub fn page_server_psql(&self, sql: &str) -> Vec<postgres::SimpleQueryMessage> {
let mut client = self.pg_connection_config.connect_no_tls().unwrap();
println!("Pageserver query: '{sql}'");
client.simple_query(sql).unwrap()
}
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> anyhow::Result<RequestBuilder> {
pub fn page_server_psql_client(&self) -> result::Result<postgres::Client, postgres::Error> {
self.pg_connection_config.connect_no_tls()
}
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
let mut builder = self.http_client.request(method, url);
if self.env.pageserver.http_auth_type == AuthType::NeonJWT {
let token = self
.env
.generate_auth_token(&Claims::new(None, Scope::PageServerApi))?;
builder = builder.bearer_auth(token)
if self.env.pageserver.auth_type == AuthType::NeonJWT {
builder = builder.bearer_auth(&self.env.pageserver.auth_token)
}
Ok(builder)
builder
}
pub fn check_status(&self) -> Result<()> {
self.http_request(Method::GET, format!("{}/status", self.http_base_url))?
self.http_request(Method::GET, format!("{}/status", self.http_base_url))
.send()?
.error_from_body()?;
Ok(())
@@ -304,7 +298,7 @@ impl PageServerNode {
pub fn tenant_list(&self) -> Result<Vec<TenantInfo>> {
Ok(self
.http_request(Method::GET, format!("{}/tenant", self.http_base_url))?
.http_request(Method::GET, format!("{}/tenant", self.http_base_url))
.send()?
.error_from_body()?
.json()?)
@@ -358,16 +352,11 @@ impl PageServerNode {
.map(|x| x.parse::<bool>())
.transpose()
.context("Failed to parse 'trace_read_requests' as bool")?,
eviction_policy: settings
.get("eviction_policy")
.map(|x| serde_json::from_str(x))
.transpose()
.context("Failed to parse 'eviction_policy' json")?,
};
if !settings.is_empty() {
bail!("Unrecognized tenant settings: {settings:?}")
}
self.http_request(Method::POST, format!("{}/tenant", self.http_base_url))?
self.http_request(Method::POST, format!("{}/tenant", self.http_base_url))
.json(&request)
.send()?
.error_from_body()?
@@ -384,7 +373,7 @@ impl PageServerNode {
}
pub fn tenant_config(&self, tenant_id: TenantId, settings: HashMap<&str, &str>) -> Result<()> {
self.http_request(Method::PUT, format!("{}/tenant/config", self.http_base_url))?
self.http_request(Method::PUT, format!("{}/tenant/config", self.http_base_url))
.json(&TenantConfigRequest {
tenant_id,
checkpoint_distance: settings
@@ -447,7 +436,7 @@ impl PageServerNode {
.http_request(
Method::GET,
format!("{}/tenant/{}/timeline", self.http_base_url, tenant_id),
)?
)
.send()?
.error_from_body()?
.json()?;
@@ -466,7 +455,7 @@ impl PageServerNode {
self.http_request(
Method::POST,
format!("{}/tenant/{}/timeline", self.http_base_url, tenant_id),
)?
)
.json(&TimelineCreateRequest {
new_timeline_id,
ancestor_start_lsn,
@@ -503,7 +492,7 @@ impl PageServerNode {
pg_wal: Option<(Lsn, PathBuf)>,
pg_version: u32,
) -> anyhow::Result<()> {
let mut client = self.page_server_psql_client()?;
let mut client = self.pg_connection_config.connect_no_tls().unwrap();
// Init base reader
let (start_lsn, base_tarfile_path) = base;

View File

@@ -1,6 +1,7 @@
use std::io::Write;
use std::path::PathBuf;
use std::process::Child;
use std::sync::Arc;
use std::{io, result};
use anyhow::Context;
@@ -10,6 +11,7 @@ use reqwest::{IntoUrl, Method};
use thiserror::Error;
use utils::{http::error::HttpErrorBody, id::NodeId};
use crate::pageserver::PageServerNode;
use crate::{
background_process,
local_env::{LocalEnv, SafekeeperConf},
@@ -63,10 +65,14 @@ pub struct SafekeeperNode {
pub env: LocalEnv,
pub http_client: Client,
pub http_base_url: String,
pub pageserver: Arc<PageServerNode>,
}
impl SafekeeperNode {
pub fn from_env(env: &LocalEnv, conf: &SafekeeperConf) -> SafekeeperNode {
let pageserver = Arc::new(PageServerNode::from_env(env));
SafekeeperNode {
id: conf.id,
conf: conf.clone(),
@@ -74,6 +80,7 @@ impl SafekeeperNode {
env: env.clone(),
http_client: Client::new(),
http_base_url: format!("http://127.0.0.1:{}/v1", conf.http_port),
pageserver,
}
}
@@ -108,10 +115,6 @@ impl SafekeeperNode {
let datadir = self.datadir_path();
let id_string = id.to_string();
// TODO: add availability_zone to the config.
// Right now we just specify any value here and use it to check metrics in tests.
let availability_zone = format!("sk-{}", id_string);
let mut args = vec![
"-D",
datadir.to_str().with_context(|| {
@@ -123,8 +126,6 @@ impl SafekeeperNode {
&listen_pg,
"--listen-http",
&listen_http,
"--availability-zone",
&availability_zone,
];
if !self.conf.sync {
args.push("--no-sync");

View File

@@ -160,7 +160,6 @@ services:
build:
context: ./compute_wrapper/
args:
- REPOSITORY=${REPOSITORY:-neondatabase}
- COMPUTE_IMAGE=compute-node-v${PG_VERSION:-14}
- TAG=${TAG:-latest}
- http_proxy=$http_proxy

View File

@@ -29,54 +29,12 @@ These components should not have access to the private key and may only get toke
The key pair is generated once for an installation of compute/pageserver/safekeeper, e.g. by `neon_local init`.
There is currently no way to rotate the key without bringing down all components.
### Best practices
See [RFC 8725: JSON Web Token Best Current Practices](https://www.rfc-editor.org/rfc/rfc8725)
### Token format
The JWT tokens in Neon use "EdDSA" as the algorithm (defined in [RFC8037](https://www.rfc-editor.org/rfc/rfc8037)).
Example:
Header:
```
{
"alg": "EdDSA",
"typ": "JWT"
}
```
Payload:
```
{
"scope": "tenant", # "tenant", "pageserverapi", or "safekeeperdata"
"tenant_id": "5204921ff44f09de8094a1390a6a50f6",
}
```
Meanings of scope:
"tenant": Provides access to all data for a specific tenant
"pageserverapi": Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
Should only be used e.g. for status check/tenant creation/list.
"safekeeperdata": Provides blanket access to all data on the safekeeper plus safekeeper-wide APIs.
Should only be used e.g. for status check.
Currently also used for connection from any pageserver to any safekeeper.
### CLI
CLI generates a key pair during call to `neon_local init` with the following commands:
```bash
openssl genpkey -algorithm ed25519 -out auth_private_key.pem
openssl pkey -in auth_private_key.pem -pubout -out auth_public_key.pem
openssl genrsa -out auth_private_key.pem 2048
openssl rsa -in auth_private_key.pem -pubout -outform PEM -out auth_public_key.pem
```
Configuration files for all components point to `public_key.pem` for JWT validation.
@@ -106,22 +64,20 @@ Their authentication is just plain PostgreSQL authentication and out of scope fo
There is no administrative API except those provided by PostgreSQL.
#### Outgoing connections
Compute connects to Pageserver for getting pages. The connection string is
configured by the `neon.pageserver_connstring` PostgreSQL GUC,
e.g. `postgresql://no_user@localhost:15028`. If the `$NEON_AUTH_TOKEN`
environment variable is set, it is used as the password for the connection. (The
pageserver uses JWT tokens for authentication, so the password is really a
token.)
Compute connects to Pageserver for getting pages.
The connection string is configured by the `neon.pageserver_connstring` PostgreSQL GUC, e.g. `postgresql://no_user:$NEON_AUTH_TOKEN@localhost:15028`.
The environment variable inside the connection string is substituted with
the JWT token.
Compute connects to Safekeepers to write and commit data. The list of safekeeper
addresses is given in the `neon.safekeepers` GUC. The connections to the
safekeepers take the password from the `$NEON_AUTH_TOKEN` environment
variable, if set.
Compute connects to Safekeepers to write and commit data.
The token is the same for all safekeepers.
It's stored in an environment variable, whose name is configured
by the `neon.safekeeper_token_env` PostgreSQL GUC.
If the GUC is unset, no token is passed.
The `compute_ctl` binary that runs before the PostgreSQL server, and launches
PostgreSQL, also makes a connection to the pageserver. It uses it to fetch the
initial "base backup" dump, to initialize the PostgreSQL data directory. It also
uses `$NEON_AUTH_TOKEN` as the password for the connection.
Note that both tokens can be (and typically are) the same;
the scope is the tenant and the token is usually passed through the
`$NEON_AUTH_TOKEN` environment variable.
### Pageserver
#### Overview
@@ -146,12 +102,10 @@ Each compute should present a token valid for the timeline's tenant.
Pageserver also has HTTP API: some parts are per-tenant,
some parts are server-wide, these are different scopes.
Authentication can be enabled separately for the HTTP mgmt API, and
for the libpq connections from compute. The `http_auth_type` and
`pg_auth_type` configuration variables in Pageserver's config may
have one of these values:
The `auth_type` configuration variable in Pageserver's config may have
either of three values:
* `Trust` removes all authentication.
* `Trust` removes all authentication. The outdated `MD5` value does likewise
* `NeonJWT` enables JWT validation.
Tokens are validated using the public key which lies in a PEM file
specified in the `auth_validation_public_key_path` config.

View File

@@ -37,9 +37,9 @@ You can specify version of neon cluster using following environment values.
- PG_VERSION: postgres version for compute (default is 14)
- TAG: the tag version of [docker image](https://registry.hub.docker.com/r/neondatabase/neon/tags) (default is latest), which is tagged in [CI test](/.github/workflows/build_and_test.yml)
```
$ cd docker-compose/
$ cd docker-compose/docker-compose.yml
$ docker-compose down # remove the conainers if exists
$ PG_VERSION=15 TAG=2937 docker-compose up --build -d # You can specify the postgres and image version
$ PG_VERSION=15 TAG=2221 docker-compose up --build -d # You can specify the postgres and image version
Creating network "dockercompose_default" with the default driver
Creating docker-compose_storage_broker_1 ... done
(...omit...)

View File

@@ -1,269 +0,0 @@
# Deleting pageserver part of tenants data from s3
Created on 08.03.23
## Motivation
Currently we dont delete pageserver part of the data from s3 when project is deleted. (The same is true for safekeepers, but this outside of the scope of this RFC).
This RFC aims to spin a discussion to come to a robust deletion solution that wont put us in into a corner for features like postponed deletion (when we keep data for user to be able to restore a project if it was deleted by accident)
## Summary
TLDR; There are two options, one based on control plane issuing actual delete requests to s3 and the other one that keeps s3 stuff bound to pageserver. Each one has its pros and cons.
The decision is to stick with pageserver centric approach. For motivation see [Decision](#decision).
## Components
pageserver, control-plane
## Requirements
Deletion should successfully finish (eventually) without leaving dangling files in presense of:
- component restarts
- component outage
- pageserver loss
## Proposed implementation
Before the options are discussed, note that deletion can be quite long process. For deletion from s3 the obvious choice is [DeleteObjects](https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html) API call. It allows to batch deletion of up to 1k objects in one API call. So deletion operation linearly depends on number of layer files.
Another design limitation is that there is no cheap `mv` operation available for s3. `mv` from `aws s3 mv` uses `copy(src, dst) + delete(src)`. So `mv`-like operation is not feasible as a building block because it actually amplifies the problem with both duration and resulting cost of the operation.
The case when there are multiple pageservers handling the same tenants is largely out of scope of the RFC. We still consider case with migration from one PS to another, but do not consider case when tenant exists on multiple pageservers for extended period of time. The case with multiple pageservers can be reduced to case with one pageservers by calling detach on all pageservers except the last one, for it actual delete needs to be called.
For simplicity lets look into deleting tenants. Differences in deletion process between tenants and timelines are mentioned in paragraph ["Differences between tenants and timelines"](#differences-between-tenants-and-timelines)
### 1. Pageserver owns deletion machinery
#### The sequence
TLDR; With this approach control plane needs to call delete on a tenant and poll for progress. As much as possible is handled on pageserver. Lets see the sequence.
Happy path:
```mermaid
sequenceDiagram
autonumber
participant CP as Control Plane
participant PS as Pageserver
participant S3
CP->>PS: Delete tenant
PS->>S3: Create deleted mark file at <br> /tenant/meta/deleted
PS->>PS: Create deleted mark file locally
PS->>CP: Accepted
PS->>PS: delete local files other than deleted mark
loop Delete layers for each timeline
PS->>S3: delete(..)
CP->>PS: Finished?
PS->>CP: False
end
PS->>S3: Delete mark file
PS->>PS: Delete local mark file
loop Poll for status
CP->>PS: Finished?
PS->>CP: True or False
end
```
Why two mark files?
Remote one is needed for cases when pageserver is lost during deletion so other pageserver can learn the deletion from s3 during attach.
Why local mark file is needed?
If we dont have one, we have two choices, delete local data before deleting the remote part or do that after.
If we delete local data before remote then during restart pageserver wont pick up remote tenant at all because nothing is available locally (pageserver looks for remote conuterparts of locally available tenants).
If we delete local data after remote then at the end of the sequence when remote mark file is deleted if pageserver restart happens then the state is the same to situation when pageserver just missing data on remote without knowing the fact that this data is intended to be deleted. In this case the current behavior is upload everything local-only to remote.
Thus we need local record of tenant being deleted as well.
##### Handle pageserver crashes
Lets explore sequences with various crash points.
Pageserver crashes before `deleted` mark file is persisted in s3:
```mermaid
sequenceDiagram
autonumber
participant CP as Control Plane
participant PS as Pageserver
participant S3
CP->>PS: Delete tenant
note over PS: Crash point 1.
CP->>PS: Retry delete request
PS->>S3: Create deleted mark file at <br> /tenant/meta/deleted
PS->>PS: Create deleted mark file locally
PS->>CP: Accepted
PS->>PS: delete local files other than deleted mark
loop Delete layers for each timeline
PS->>S3: delete(..)
CP->>PS: Finished?
PS->>CP: False
end
PS->>S3: Delete mark file
PS->>PS: Delete local mark file
CP->>PS: Finished?
PS->>CP: True
```
Pageserver crashed when deleted mark was about to be persisted in s3, before Control Plane gets a response:
```mermaid
sequenceDiagram
autonumber
participant CP as Control Plane
participant PS as Pageserver
participant S3
CP->>PS: Delete tenant
PS->>S3: Create deleted mark file at <br> /tenant/meta/deleted
note over PS: Crash point 2.
note over PS: During startup we reconcile <br> with remote and see <br> whether the remote mark exists
alt Remote mark exists
PS->>PS: create local mark if its missing
PS->>PS: delete local files other than deleted mark
loop Delete layers for each timeline
PS->>S3: delete(..)
end
note over CP: Eventually console should <br> retry delete request
CP->>PS: Retry delete tenant
PS->>CP: Not modified
else Mark is missing
note over PS: Continue to operate the tenant as if deletion didnt happen
note over CP: Eventually console should <br> retry delete request
CP->>PS: Retry delete tenant
PS->>S3: Create deleted mark file at <br> /tenant/meta/deleted
PS->>CP: Delete tenant
end
PS->>PS: Continue with layer file deletions
loop Delete layers for each timeline
PS->>S3: delete(..)
CP->>PS: Finished?
PS->>CP: False
end
PS->>S3: Delete mark file
PS->>PS: Delete local mark file
CP->>PS: Finished?
PS->>CP: True
```
Similar sequence applies when both local and remote marks were persisted but Control Plane still didnt receive a response.
If pageserver crashes after both mark files were deleted then it will reply to control plane status poll request with 404 which should be treated by control plane as success.
The same applies if pageserver crashes in the end, when remote mark is deleted but before local one gets deleted. In this case on restart pageserver moves forward with deletion of local mark and Control Plane will receive 404.
##### Differences between tenants and timelines
For timeline the sequence is the same with the following differences:
- remote delete mark file can be replaced with a boolean "deleted" flag in index_part.json
- local deletion mark is not needed, because whole tenant is kept locally so situation described in motivation for local mark is impossible
##### Handle pageserver loss
If pageseserver is lost then the deleted tenant should be attached to different pageserver and delete request needs to be retried against new pageserver. Then attach logic is shared with one described for pageserver restarts (local deletion mark wont be available so needs to be created).
##### Restrictions for tenant that is in progress of being deleted
I propose to add another state to tenant/timeline - PendingDelete. This state shouldnt allow executing any operations aside from polling the deletion status.
#### Summary
Pros:
- Storage is not dependent on control plane. Storage can be restarted even if control plane is not working.
- Allows for easier dogfooding, console can use Neon backed database as primary operational data store. If storage depends on control plane and control plane depends on storage we're stuck.
- No need to share inner s3 workings with control plane. Pageserver presents api contract and S3 paths are not part of this contract.
- No need to pass list of alive timelines to attach call. This will be solved by pageserver observing deleted flag. See
Cons:
- Logic is a tricky, needs good testing
- Anything else?
### 2. Control plane owns deletion machinery
In this case the only action performed on pageserver is removal of local files.
Everything else is done by control plane. The steps are as follows:
1. Control plane marks tenant as "delete pending" in its database
2. It lists the s3 for all the files and repeatedly calls delete until nothing is left behind
3. When no files are left marks deletion as completed
In case of restart it selects all tenants marked as "delete pending" and continues the deletion.
For tenants it is simple. For timelines there are caveats.
Assume that the same workflow is used for timelines.
If a tenant gets relocated during timeline deletion the attach call with its current logic will pick up deleted timeline in its half deleted state.
Available options:
- require list of alive timelines to be passed to attach call
- use the same schema with flag in index_part.json (again part of the caveats around pageserver restart applies). In this case nothing stops pageserver from implementing deletion inside if we already have these deletion marks.
With first option the following problem becomes apparent:
Who is the source of truth regarding timeline liveness?
Imagine:
PS1 fails.
PS2 gets assigned the tenant.
New branch gets created
PS1 starts up (is it possible or we just recycle it?)
PS1 is unaware of the new branch. It can either fall back to s3 ls, or ask control plane.
So here comes the dependency of storage on control plane. During restart storage needs to know which timelines are valid for operation. If there is nothing on s3 that can answer that question storage neeeds to ask control plane.
### Summary
Cons:
- Potential thundering herd-like problem during storage restart (requests to control plane)
- Potential increase in storage startup time (additional request to control plane)
- Storage startup starts to depend on console
- Erroneous attach call can attach tenant in half deleted state
Pros:
- Easier to reason about if you dont have to account for pageserver restarts
### Extra notes
There was a concern that having deletion code in pageserver is a littlebit scary, but we need to have this code somewhere. So to me it is equally scary to have that in whatever place it ends up at.
Delayed deletion can be done with both approaches. As discussed with Anna (@stepashka) this is only relevant for tenants (projects) not for timelines. For first approach detach can be called immediately and deletion can be done later with attach + delete. With second approach control plane needs to start the deletion whenever necessary.
## Decision
After discussion in comments I see that we settled on two options (though a bit different from ones described in rfc). First one is the same - pageserver owns as much as possible. The second option is that pageserver owns markers thing, but actual deletion happens in control plane by repeatedly calling ls + delete.
To my mind the only benefit of the latter approach is possible code reuse between safekeepers and pageservers. Otherwise poking around integrating s3 library into control plane, configuring shared knowledge abouth paths in s3 - are the downsides. Another downside of relying on control plane is the testing process. Control plane resides in different repository so it is quite hard to test pageserver related changes there. e2e test suite there doesnt support shutting down pageservers, which are separate docker containers there instead of just processes.
With pageserver owning everything we still give the retry logic to control plane but its easier to duplicate if needed compared to sharing inner s3 workings. We will have needed tests for retry logic in neon repo.
So the decision is to proceed with pageserver centric approach.

View File

@@ -129,12 +129,13 @@ Run `poetry shell` to activate the virtual environment.
Alternatively, use `poetry run` to run a single command in the venv, e.g. `poetry run pytest`.
### Obligatory checks
We force code formatting via `black`, `ruff`, and type hints via `mypy`.
We force code formatting via `black`, `isort` and type hints via `mypy`.
Run the following commands in the repository's root (next to `pyproject.toml`):
```bash
poetry run isort . # Imports are reformatted
poetry run black . # All code is reformatted
poetry run ruff . # Python linter
poetry run flake8 . # Python linter
poetry run mypy . # Ensure there are no typing errors
```

View File

@@ -115,11 +115,6 @@ pub struct TenantCreateRequest {
pub lagging_wal_timeout: Option<String>,
pub max_lsn_wal_lag: Option<NonZeroU64>,
pub trace_read_requests: Option<bool>,
// We defer the parsing of the eviction_policy field to the request handler.
// Otherwise we'd have to move the types for eviction policy into this package.
// We might do that once the eviction feature has stabilizied.
// For now, this field is not even documented in the openapi_spec.yml.
pub eviction_policy: Option<serde_json::Value>,
}
#[serde_as]
@@ -346,7 +341,7 @@ pub enum InMemoryLayerInfo {
pub enum HistoricLayerInfo {
Delta {
layer_file_name: String,
layer_file_size: u64,
layer_file_size: Option<u64>,
#[serde_as(as = "DisplayFromStr")]
lsn_start: Lsn,
@@ -357,7 +352,7 @@ pub enum HistoricLayerInfo {
},
Image {
layer_file_name: String,
layer_file_size: u64,
layer_file_size: Option<u64>,
#[serde_as(as = "DisplayFromStr")]
lsn_start: Lsn,

View File

@@ -1,26 +0,0 @@
[package]
name = "postgres_backend"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
async-trait.workspace = true
anyhow.workspace = true
bytes.workspace = true
futures.workspace = true
rustls.workspace = true
serde.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-rustls.workspace = true
tracing.workspace = true
pq_proto.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
once_cell.workspace = true
rustls-pemfile.workspace = true
tokio-postgres.workspace = true
tokio-postgres-rustls.workspace = true

View File

@@ -1,931 +0,0 @@
//! Server-side asynchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use anyhow::Context;
use bytes::Bytes;
use futures::pin_mut;
use serde::{Deserialize, Serialize};
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Poll};
use std::{fmt, io};
use std::{future::Future, str::FromStr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, trace};
use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
use pq_proto::{
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
SQLSTATE_SUCCESSFUL_COMPLETION,
};
/// An error, occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
/// The connection was lost while processing the query.
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl From<io::Error> for QueryError {
fn from(e: io::Error) -> Self {
Self::Disconnected(ConnectionError::Io(e))
}
}
impl QueryError {
pub fn pg_error_code(&self) -> &'static [u8; 5] {
match self {
Self::Disconnected(_) => b"08006", // connection failure
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
}
}
}
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
#[async_trait::async_trait]
pub trait Handler<IO> {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care). It will also flush out the output buffer.
async fn process_query(
&mut self,
pgb: &mut PostgresBackend<IO>,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend<IO>,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend<IO>,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
/// Nothing happened yet.
Initialization,
/// Encryption handshake is done; waiting for encrypted Startup message.
Encrypted,
/// Waiting for password (auth token).
Authentication,
/// Performed handshake and auth, ReadyForQuery is issued.
Established,
Closed,
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
pub enum MaybeTlsStream<IO> {
Unencrypted(IO),
Tls(Box<tokio_rustls::server::TlsStream<IO>>),
}
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
}
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
})
}
}
/// Either full duplex Framed or write only half; the latter is left in
/// PostgresBackend after call to `split`. In principle we could always store a
/// pair of splitted handles, but that would force to to pay splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly<IO> {
Full(Framed<MaybeTlsStream<IO>>),
WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
Broken, // temporary value palmed off during the split
}
impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
MaybeWriteOnly::WriteOnly(_) => {
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
}
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_message().await,
MaybeWriteOnly::WriteOnly(_) => {
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
}
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
match self {
MaybeWriteOnly::Full(framed) => framed.write_message(msg),
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn flush(&mut self) -> io::Result<()> {
match self {
MaybeWriteOnly::Full(framed) => framed.flush().await,
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn shutdown(&mut self) -> io::Result<()> {
match self {
MaybeWriteOnly::Full(framed) => framed.shutdown().await,
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
}
pub struct PostgresBackend<IO> {
framed: MaybeWriteOnly<IO>,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
/// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend<tokio::net::TcpStream> {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
let stream = MaybeTlsStream::Unencrypted(socket);
Ok(Self {
framed: MaybeWriteOnly::Full(Framed::new(stream)),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
pub fn new_from_io(
socket: IO,
peer_addr: SocketAddr,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let stream = MaybeTlsStream::Unencrypted(socket);
Ok(Self {
framed: MaybeWriteOnly::Full(Framed::new(stream)),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
if let ProtoState::Closed = self.state {
Ok(None)
} else {
let m = self.framed.read_message().await?;
trace!("read msg {:?}", m);
Ok(m)
}
}
/// Write message into internal output buffer, doesn't flush it. Technically
/// error type can be only ProtocolError here (if, unlikely, serialization
/// fails), but callers typically wrap it anyway.
pub fn write_message_noflush(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.framed.write_message_noflush(message)?;
trace!("wrote msg {:?}", message);
Ok(self)
}
/// Flush output buffer into the socket.
pub async fn flush(&mut self) -> io::Result<()> {
self.framed.flush().await
}
/// Polling version of `flush()`, saves the caller need to pin.
pub fn poll_flush(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let flush_fut = self.flush();
pin_mut!(flush_fut);
flush_fut.poll(cx)
}
/// Write message into internal output buffer and flush it to the stream.
pub async fn write_message(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.write_message_noflush(message)?;
self.flush().await?;
Ok(self)
}
/// Returns an AsyncWrite implementation that wraps all the data written
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
CopyDataWriter { pgb: self }
}
/// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
let ret = self.run_message_loop(handler, shutdown_watcher).await;
// socket might be already closed, e.g. if previously received error,
// so ignore result.
self.framed.shutdown().await.ok();
ret
}
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
trace!("postgres backend to {:?} started", self.peer_addr);
tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received during handshake");
return Ok(())
},
result = self.handshake(handler) => {
// Handshake complete.
result?;
if self.state == ProtoState::Closed {
return Ok(()); // EOF during handshake
}
}
);
// Authentication completed
let mut query_string = Bytes::new();
while let Some(msg) = tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received in run_message_loop");
Ok(None)
},
msg = self.read_message() => { msg },
)? {
trace!("got message {:?}", msg);
let result = self.process_message(handler, msg, &mut query_string).await;
self.flush().await?;
match result? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => break,
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
async fn tls_upgrade(
src: MaybeTlsStream<IO>,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<MaybeTlsStream<IO>> {
match src {
MaybeTlsStream::Unencrypted(s) => {
let acceptor = TlsAcceptor::from(tls_config);
let tls_stream = acceptor.accept(s).await?;
Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
}
MaybeTlsStream::Tls(_) => {
anyhow::bail!("TLS already started");
}
}
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
// temporary replace stream with fake to cook TLS one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
let tls_config = self
.tls_config
.as_ref()
.context("start_tls called without conf")?
.clone();
let tls_framed = framed
.map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
.await?;
// push back ready TLS stream
self.framed = MaybeWriteOnly::Full(tls_framed);
Ok(())
}
MaybeWriteOnly::WriteOnly(_) => {
anyhow::bail!("TLS upgrade attempt in split state")
}
MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
}
}
/// Split off owned read part from which messages can be read in different
/// task/thread.
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
// temporary replace stream with fake to cook split one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
let (reader, writer) = framed.split();
self.framed = MaybeWriteOnly::WriteOnly(writer);
Ok(PostgresBackendReader(reader))
}
MaybeWriteOnly::WriteOnly(_) => {
anyhow::bail!("PostgresBackend is already split")
}
MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
}
}
/// Join read part back.
pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
// temporary replace stream with fake to cook joined one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(_) => {
anyhow::bail!("PostgresBackend is not split")
}
MaybeWriteOnly::WriteOnly(writer) => {
let joined = Framed::unsplit(reader.0, writer);
self.framed = MaybeWriteOnly::Full(joined);
Ok(())
}
MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
}
}
/// Perform handshake with the client, transitioning to Established.
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
while self.state < ProtoState::Authentication {
match self.framed.read_startup_message().await? {
Some(msg) => {
self.process_startup_message(handler, msg).await?;
}
None => {
trace!(
"postgres backend to {:?} received EOF during handshake",
self.peer_addr
);
self.state = ProtoState::Closed;
return Ok(());
}
}
}
// Perform auth, if needed.
if self.state == ProtoState::Authentication {
match self.framed.read_message().await? {
Some(FeMessage::PasswordMessage(m)) => {
assert!(self.auth_type == AuthType::NeonJWT);
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)
.await?;
self.state = ProtoState::Established;
}
Some(m) => {
return Err(QueryError::Other(anyhow::anyhow!(
"Unexpected message {:?} while waiting for handshake",
m
)));
}
None => {
trace!(
"postgres backend to {:?} received EOF during auth",
self.peer_addr
);
self.state = ProtoState::Closed;
return Ok(());
}
}
}
Ok(())
}
/// Process startup packet:
/// - transition to Established if auth type is trust
/// - transition to Authentication if auth type is NeonJWT.
/// - or perform TLS handshake -- then need to call this again to receive
/// actual startup packet.
async fn process_startup_message(
&mut self,
handler: &mut impl Handler<IO>,
msg: FeStartupPacket,
) -> Result<(), QueryError> {
assert!(self.state < ProtoState::Authentication);
let have_tls = self.tls_config.is_some();
match msg {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))
.await?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))
.await?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
.await?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &msg)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)
.await?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)
.await?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
return Err(QueryError::Other(anyhow::anyhow!(
"Unexpected CancelRequest message during handshake"
)));
}
}
Ok(())
}
async fn process_message(
&mut self,
handler: &mut impl Handler<IO>,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
assert!(self.state == ProtoState::Established);
match msg {
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message_noflush(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message_noflush(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message_noflush(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message_noflush(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_)
| FeMessage::CopyDone
| FeMessage::CopyFail
| FeMessage::PasswordMessage(_) => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {msg:?}",
)));
}
}
Ok(ProcessMsgResult::Continue)
}
/// Log as info/error result of handling COPY stream and send back
/// ErrorResponse if that makes sense. Shutdown the stream if we got
/// Terminate. TODO: transition into waiting for Sync msg if we initiate the
/// close.
pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
use CopyStreamHandlerEnd::*;
let expected_end = match &end {
ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
if is_expected_io_error(io_error) =>
{
true
}
_ => false,
};
if expected_end {
info!("terminated: {:#}", end);
} else {
error!("terminated: {:?}", end);
}
// Note: no current usages ever send this
if let CopyDone = &end {
if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
error!("failed to send CopyDone: {}", e);
}
}
if let Terminate = &end {
self.state = ProtoState::Closed;
}
let err_to_send_and_errcode = match &end {
ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
Other(_) => Some((format!("{end:#}"), SQLSTATE_INTERNAL_ERROR)),
// Note: CopyFail in duplex copy is somewhat unexpected (at least to
// PG walsender; evidently and per my docs reading client should
// finish it with CopyDone). It is not a problem to recover from it
// finishing the stream in both directions like we do, but note that
// sync rust-postgres client (which we don't use anymore) hangs if
// socket is not closed here.
// https://github.com/sfackler/rust-postgres/issues/755
// https://github.com/neondatabase/neon/issues/935
//
// Currently, the version of tokio_postgres replication patch we use
// sends this when it closes the stream (e.g. pageserver decided to
// switch conn to another safekeeper and client gets dropped).
// Moreover, seems like 'connection' task errors with 'unexpected
// message from server' when it receives ErrorResponse (anything but
// CopyData/CopyDone) back.
CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
_ => None,
};
if let Some((err, errcode)) = err_to_send_and_errcode {
if let Err(ee) = self
.write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
.await
{
error!("failed to send ErrorResponse: {}", ee);
}
}
}
}
pub struct PostgresBackendReader<IO>(FramedReader<MaybeTlsStream<IO>>);
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
let m = self.0.read_message().await?;
trace!("read msg {:?}", m);
Ok(m)
}
/// Get CopyData contents of the next message in COPY stream or error
/// closing it. The error type is wider than actual errors which can happen
/// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
/// current callers.
pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
match self.read_message().await? {
Some(msg) => match msg {
FeMessage::CopyData(m) => Ok(m),
FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
_ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
))),
},
None => Err(CopyStreamHandlerEnd::EOF),
}
}
}
///
/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
/// messages.
///
pub struct CopyDataWriter<'a, IO> {
pgb: &'a mut PostgresBackend<IO>,
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
// It's not strictly required to flush between each message, but makes it easier
// to view in wireshark, and usually the messages that the callers write are
// decently-sized anyway.
if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
return Poll::Ready(Err(err));
}
// CopyData
// XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32.
this.pgb
.write_message_noflush(&BeMessage::CopyData(buf))
// write_message only writes to the buffer, so it can fail iff the
// message is invaid, but CopyData can't be invalid.
.map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
this.pgb.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
this.pgb.poll_flush(cx)
}
}
pub fn short_error(e: &QueryError) -> String {
match e {
QueryError::Disconnected(connection_error) => connection_error.to_string(),
QueryError::Other(e) => format!("{e:#}"),
}
}
fn log_query_error(query: &str, e: &QueryError) {
match e {
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
if is_expected_io_error(io_error) {
info!("query handler for '{query}' failed with expected io error: {io_error}");
} else {
error!("query handler for '{query}' failed with io error: {io_error}");
}
}
QueryError::Disconnected(other_connection_error) => {
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
}
QueryError::Other(e) => {
error!("query handler for '{query}' failed: {e:?}");
}
}
}
/// Something finishing handling of COPY stream, see handle_copy_stream_end.
/// This is not always a real error, but it allows to use ? and thiserror impls.
#[derive(thiserror::Error, Debug)]
pub enum CopyStreamHandlerEnd {
/// Handler initiates the end of streaming.
#[error("{0}")]
ServerInitiated(String),
#[error("received CopyDone")]
CopyDone,
#[error("received CopyFail")]
CopyFail,
#[error("received Terminate")]
Terminate,
#[error("EOF on COPY stream")]
EOF,
/// The connection was lost
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}

View File

@@ -1,140 +0,0 @@
/// Test postgres_backend_async with tokio_postgres
use once_cell::sync::Lazy;
use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
use pq_proto::{BeMessage, RowDescriptor};
use std::io::Cursor;
use std::{future, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
use tokio_postgres_rustls::MakeRustlsConnect;
// generate client, server test streams
async fn make_tcp_pair() -> (TcpStream, TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(addr).await.unwrap();
let (server_stream, _) = listener.accept().await.unwrap();
(client_stream, server_stream)
}
struct TestHandler {}
#[async_trait::async_trait]
impl<IO: AsyncRead + AsyncWrite + Unpin + Send> Handler<IO> for TestHandler {
// return single col 'hey' for any query
async fn process_query(
&mut self,
pgb: &mut PostgresBackend<IO>,
_query_string: &str,
) -> Result<(), QueryError> {
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
b"hey",
)]))?
.write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
Ok(())
}
}
// test that basic select works
#[tokio::test]
async fn simple_select() {
let (client_sock, server_sock) = make_tcp_pair().await;
// create and run pgbackend
let pgbackend =
PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
});
let conf = Config::new();
let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
if let SimpleQueryMessage::Row(row) = first_val {
let first_col = row.get(0).expect("first column");
assert_eq!(first_col, "hey");
} else {
panic!("expected SimpleQueryMessage::Row");
}
}
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("key.pem"));
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
});
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
});
// test that basic select with ssl works
#[tokio::test]
async fn simple_select_ssl() {
let (client_sock, server_sock) = make_tcp_pair().await;
let server_cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(server_cfg));
let pgbackend =
PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
});
let client_cfg = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&CERT).unwrap();
store
})
.with_no_client_auth();
let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
&mut make_tls_connect,
"localhost",
)
.expect("make_tls_connect");
let mut conf = Config::new();
conf.ssl_mode(SslMode::Require);
let (client, connection) = conf
.connect_raw(client_sock, tls_connect)
.await
.expect("connect");
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
if let SimpleQueryMessage::Row(row) = first_val {
let first_col = row.get(0).expect("first column");
assert_eq!(first_col, "hey");
} else {
panic!("expected SimpleQueryMessage::Row");
}
}

View File

@@ -63,7 +63,10 @@ fn main() -> anyhow::Result<()> {
pg_install_dir_versioned = cwd.join("..").join("..").join(pg_install_dir_versioned);
}
let pg_config_bin = pg_install_dir_versioned.join("bin").join("pg_config");
let pg_config_bin = pg_install_dir_versioned
.join(pg_version)
.join("bin")
.join("pg_config");
let inc_server_path: String = if pg_config_bin.exists() {
let output = Command::new(pg_config_bin)
.arg("--includedir-server")

View File

@@ -5,8 +5,8 @@ edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
bytes.workspace = true
byteorder.workspace = true
pin-project-lite.workspace = true
postgres-protocol.workspace = true
rand.workspace = true

View File

@@ -1,244 +0,0 @@
//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
//! the async stream based on (and buffered with) BytesMut. All functions are
//! cancellation safe.
//!
//! It is similar to what tokio_util::codec::Framed with appropriate codec
//! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
//! separately without using split from futures::stream::StreamExt (which
//! allocates box[1] in polling internally). tokio::io::split is used for splitting
//! instead. Plus we customize error messages more than a single type for all io
//! calls.
//!
//! [1] https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
use bytes::{Buf, BytesMut};
use std::{
future::Future,
io::{self, ErrorKind},
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
const INITIAL_CAPACITY: usize = 8 * 1024;
/// Error on postgres connection: either IO (physical transport error) or
/// protocol violation.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
#[error(transparent)]
Io(#[from] io::Error),
#[error(transparent)]
Protocol(#[from] ProtocolError),
}
impl ConnectionError {
/// Proxy stream.rs uses only io::Error; provide it.
pub fn into_io_error(self) -> io::Error {
match self {
ConnectionError::Io(io) => io,
ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
}
}
}
/// Wraps async io `stream`, providing messages to write/flush + read Postgres
/// messages.
pub struct Framed<S> {
stream: S,
read_buf: BytesMut,
write_buf: BytesMut,
}
impl<S> Framed<S> {
pub fn new(stream: S) -> Self {
Self {
stream,
read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
}
}
/// Get a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Deconstruct into the underlying stream and read buffer.
pub fn into_inner(self) -> (S, BytesMut) {
(self.stream, self.read_buf)
}
/// Return new Framed with stream type transformed by async f, for TLS
/// upgrade.
pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
where
F: FnOnce(S) -> Fut,
Fut: Future<Output = Result<S2, E>>,
{
let stream = f(self.stream).await?;
Ok(Framed {
stream,
read_buf: self.read_buf,
write_buf: self.write_buf,
})
}
}
impl<S: AsyncRead + Unpin> Framed<S> {
pub async fn read_startup_message(
&mut self,
) -> Result<Option<FeStartupPacket>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
}
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
}
}
impl<S: AsyncWrite + Unpin> Framed<S> {
/// Write next message to the output buffer; doesn't flush.
pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
BeMessage::write(&mut self.write_buf, msg)
}
/// Flush out the buffer. This function is cancellation safe: it can be
/// interrupted and flushing will be continued in the next call.
pub async fn flush(&mut self) -> Result<(), io::Error> {
flush(&mut self.stream, &mut self.write_buf).await
}
/// Flush out the buffer and shutdown the stream.
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
shutdown(&mut self.stream, &mut self.write_buf).await
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
/// Split into owned read and write parts. Beware of potential issues with
/// using halves in different tasks on TLS stream:
/// https://github.com/tokio-rs/tls/issues/40
pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
let (read_half, write_half) = tokio::io::split(self.stream);
let reader = FramedReader {
stream: read_half,
read_buf: self.read_buf,
};
let writer = FramedWriter {
stream: write_half,
write_buf: self.write_buf,
};
(reader, writer)
}
/// Join read and write parts back.
pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
Self {
stream: reader.stream.unsplit(writer.stream),
read_buf: reader.read_buf,
write_buf: writer.write_buf,
}
}
}
/// Read-only version of `Framed`.
pub struct FramedReader<S> {
stream: ReadHalf<S>,
read_buf: BytesMut,
}
impl<S: AsyncRead + Unpin> FramedReader<S> {
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
}
}
/// Write-only version of `Framed`.
pub struct FramedWriter<S> {
stream: WriteHalf<S>,
write_buf: BytesMut,
}
impl<S: AsyncWrite + Unpin> FramedWriter<S> {
/// Write next message to the output buffer; doesn't flush.
pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
BeMessage::write(&mut self.write_buf, msg)
}
/// Flush out the buffer. This function is cancellation safe: it can be
/// interrupted and flushing will be continued in the next call.
pub async fn flush(&mut self) -> Result<(), io::Error> {
flush(&mut self.stream, &mut self.write_buf).await
}
/// Flush out the buffer and shutdown the stream.
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
shutdown(&mut self.stream, &mut self.write_buf).await
}
}
/// Read next message from the stream. Returns Ok(None), if EOF happened and we
/// don't have remaining data in the buffer. This function is cancellation safe:
/// you can drop future which is not yet complete and finalize reading message
/// with the next call.
///
/// Parametrized to allow reading startup or usual message, having different
/// format.
async fn read_message<S: AsyncRead + Unpin, M, P>(
stream: &mut S,
read_buf: &mut BytesMut,
parse: P,
) -> Result<Option<M>, ConnectionError>
where
P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
{
loop {
if let Some(msg) = parse(read_buf)? {
return Ok(Some(msg));
}
// If we can't build a frame yet, try to read more data and try again.
// Make sure we've got room for at least one byte to read to ensure
// that we don't get a spurious 0 that looks like EOF.
read_buf.reserve(1);
if stream.read_buf(read_buf).await? == 0 {
if read_buf.has_remaining() {
return Err(io::Error::new(
ErrorKind::UnexpectedEof,
"EOF with unprocessed data in the buffer",
)
.into());
} else {
return Ok(None); // clean EOF
}
}
}
}
async fn flush<S: AsyncWrite + Unpin>(
stream: &mut S,
write_buf: &mut BytesMut,
) -> Result<(), io::Error> {
while write_buf.has_remaining() {
let bytes_written = stream.write(write_buf.chunk()).await?;
if bytes_written == 0 {
return Err(io::Error::new(
ErrorKind::WriteZero,
"failed to write message",
));
}
// The advanced part will be garbage collected, likely during shifting
// data left on next attempt to write to buffer when free space is not
// enough.
write_buf.advance(bytes_written);
}
write_buf.clear();
stream.flush().await
}
async fn shutdown<S: AsyncWrite + Unpin>(
stream: &mut S,
write_buf: &mut BytesMut,
) -> Result<(), io::Error> {
flush(stream, write_buf).await?;
stream.shutdown().await
}

View File

@@ -2,18 +2,24 @@
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
//! on message formats.
pub mod framed;
// Tools for calling certain async methods in sync contexts.
pub mod sync;
use byteorder::{BigEndian, ReadBytesExt};
use anyhow::{ensure, Context, Result};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use postgres_protocol::PG_EPOCH;
use serde::{Deserialize, Serialize};
use std::{
borrow::Cow,
collections::HashMap,
fmt, io, str,
fmt,
future::Future,
io::{self, Cursor},
str,
time::{Duration, SystemTime},
};
use sync::{AsyncishRead, SyncFuture};
use tokio::io::AsyncReadExt;
use tracing::{trace, warn};
pub type Oid = u32;
@@ -25,6 +31,7 @@ pub const TEXT_OID: Oid = 25;
#[derive(Debug)]
pub enum FeMessage {
StartupPacket(FeStartupPacket),
// Simple query.
Query(Bytes),
// Extended query protocol.
@@ -184,205 +191,260 @@ pub struct FeExecuteMessage {
#[derive(Debug)]
pub struct FeCloseMessage;
/// An error occured while parsing or serializing raw stream into Postgres
/// messages.
#[derive(thiserror::Error, Debug)]
pub enum ProtocolError {
/// Invalid packet was received from the client (e.g. unexpected message
/// type or broken len).
#[error("Protocol error: {0}")]
Protocol(String),
/// Failed to parse or, (unlikely), serialize a protocol message.
#[error("Message parse error: {0}")]
BadMessage(String),
/// Retry a read on EINTR
///
/// This runs the enclosed expression, and if it returns
/// Err(io::ErrorKind::Interrupted), retries it.
macro_rules! retry_read {
( $x:expr ) => {
loop {
match $x {
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
res => break res,
}
}
};
}
impl ProtocolError {
/// Proxy stream.rs uses only io::Error; provide it.
/// An error occured during connection being open.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
/// IO error during writing to or reading from the connection socket.
#[error("Socket IO error: {0}")]
Socket(std::io::Error),
/// Invalid packet was received from client
#[error("Protocol error: {0}")]
Protocol(String),
/// Failed to parse a protocol mesage
#[error("Message parse error: {0}")]
MessageParse(anyhow::Error),
}
impl From<anyhow::Error> for ConnectionError {
fn from(e: anyhow::Error) -> Self {
Self::MessageParse(e)
}
}
impl ConnectionError {
pub fn into_io_error(self) -> io::Error {
io::Error::new(io::ErrorKind::Other, self.to_string())
match self {
ConnectionError::Socket(io) => io,
other => io::Error::new(io::ErrorKind::Other, other.to_string()),
}
}
}
impl FeMessage {
/// Read and parse one message from the `buf` input buffer. If there is at
/// least one valid message, returns it, advancing `buf`; redundant copies
/// are avoided, as thanks to `bytes` crate ptrs in parsed message point
/// directly into the `buf` (processed data is garbage collected after
/// parsed message is dropped).
/// Read one message from the stream.
/// This function returns `Ok(None)` in case of EOF.
/// One way to handle this properly:
///
/// Returns None if `buf` doesn't contain enough data for a single message.
/// For efficiency, tries to reserve large enough space in `buf` for the
/// next message in this case to save the repeated calls.
/// ```
/// # use std::io;
/// # use pq_proto::FeMessage;
/// #
/// # fn process_message(msg: FeMessage) -> anyhow::Result<()> {
/// # Ok(())
/// # };
/// #
/// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> {
/// while let Some(msg) = FeMessage::read(stream)? {
/// process_message(msg)?;
/// }
///
/// Returns Error if message is malformed, the only possible ErrorKind is
/// InvalidInput.
//
// Inspired by rust-postgres Message::parse.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
// Every message contains message type byte and 4 bytes len; can't do
// much without them.
if buf.len() < 5 {
let to_read = 5 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
/// Ok(())
/// }
/// ```
#[inline(never)]
pub fn read(
stream: &mut (impl io::Read + Unpin),
) -> Result<Option<FeMessage>, ConnectionError> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
// We shouldn't advance `buf` as probably full message is not there yet,
// so can't directly use Bytes::get_u32 etc.
let tag = buf[0];
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
if len < 4 {
return Err(ProtocolError::Protocol(format!(
"invalid message length {}",
len
)));
}
/// Read one message from the stream.
/// See documentation for `Self::read`.
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
// AsyncReadExt methods of the stream.
SyncFuture::new(async move {
// Each libpq message begins with a message type byte, followed by message length
// If the client closes the connection, return None. But if the client closes the
// connection in the middle of a message, we will return an error.
let tag = match retry_read!(stream.read_u8().await) {
Ok(b) => b,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ConnectionError::Socket(e)),
};
// length field includes itself, but not message type.
let total_len = len as usize + 1;
if buf.len() < total_len {
// Don't have full message yet.
let to_read = total_len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// The message length includes itself, so it better be at least 4.
let len = retry_read!(stream.read_u32().await)
.map_err(ConnectionError::Socket)?
.checked_sub(4)
.ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?;
// got the message, advance buffer
let mut msg = buf.split_to(total_len).freeze();
msg.advance(5); // consume message type and len
let body = {
let mut buffer = vec![0u8; len as usize];
stream
.read_exact(&mut buffer)
.await
.map_err(ConnectionError::Socket)?;
Bytes::from(buffer)
};
match tag {
b'Q' => Ok(Some(FeMessage::Query(msg))),
b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(msg))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
tag => Err(ProtocolError::Protocol(format!(
"unknown message tag: {tag},'{msg:?}'"
))),
}
match tag {
b'Q' => Ok(Some(FeMessage::Query(body))),
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(body))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
tag => {
return Err(ConnectionError::Protocol(format!(
"unknown message tag: {tag},'{body:?}'"
)))
}
}
})
}
}
impl FeStartupPacket {
/// Read and parse startup message from the `buf` input buffer. It is
/// different from [`FeMessage::parse`] because startup messages don't have
/// message type byte; otherwise, its comments apply.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read(
stream: &mut (impl io::Read + Unpin),
) -> Result<Option<FeMessage>, ConnectionError> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
const CANCEL_REQUEST_CODE: u32 = 5678;
const NEGOTIATE_SSL_CODE: u32 = 5679;
const NEGOTIATE_GSS_CODE: u32 = 5680;
// need at least 4 bytes with packet len
if buf.len() < 4 {
let to_read = 4 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
SyncFuture::new(async move {
// Read length. If the connection is closed before reading anything (or before
// reading 4 bytes, to be precise), return None to indicate that the connection
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
// in the log if the client opens connection but closes it immediately.
let len = match retry_read!(stream.read_u32().await) {
Ok(len) => len as usize,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ConnectionError::Socket(e)),
};
// We shouldn't advance `buf` as probably full message is not there yet,
// so can't directly use Bytes::get_u32 etc.
let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
return Err(ProtocolError::Protocol(format!(
"invalid startup packet message length {}",
len
)));
}
if buf.len() < len {
// Don't have full message yet.
let to_read = len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// got the message, advance buffer
let mut msg = buf.split_to(len).freeze();
msg.advance(4); // consume len
let request_code = msg.get_u32();
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
if msg.remaining() != 8 {
return Err(ProtocolError::BadMessage(
"CancelRequest message is malformed, backend PID / secret key missing"
.to_owned(),
));
}
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: msg.get_i32(),
cancel_key: msg.get_i32(),
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
return Err(ProtocolError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
#[allow(clippy::manual_range_contains)]
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
return Err(ConnectionError::Protocol(format!(
"invalid message length {len}"
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// StartupMessage
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&msg)
.map_err(|_e| {
ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
})?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
.split_terminator('\0');
let request_code =
retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?;
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
// the rest of startup packet are params
let params_len = len - 8;
let mut params_bytes = vec![0u8; params_len];
stream
.read_exact(params_bytes.as_mut())
.await
.map_err(ConnectionError::Socket)?;
params.insert(name.to_owned(), value.to_owned());
// Parse params depending on request code
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
if params_len != 8 {
return Err(ConnectionError::Protocol(
"expected 8 bytes for CancelRequest params".to_string(),
));
}
let mut cursor = Cursor::new(params_bytes);
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
})
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params: StartupMessageParams { params },
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
}
};
Ok(Some(message))
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
return Err(ConnectionError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&params_bytes)
.context("StartupMessage params: invalid utf-8")?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ConnectionError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
.split_terminator('\0');
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ConnectionError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
params.insert(name.to_owned(), value.to_owned());
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params: StartupMessageParams { params },
}
}
};
Ok(Some(FeMessage::StartupPacket(message)))
})
}
}
impl FeParseMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
// FIXME: the rust-postgres driver uses a named prepared statement
// for copy_out(). We're not prepared to handle that correctly. For
// now, just ignore the statement name, assuming that the client never
@@ -390,82 +452,55 @@ impl FeParseMessage {
let _pstmt_name = read_cstr(&mut buf)?;
let query_string = read_cstr(&mut buf)?;
if buf.remaining() < 2 {
return Err(ProtocolError::BadMessage(
"Parse message is malformed, nparams missing".to_string(),
));
}
let nparams = buf.get_i16();
if nparams != 0 {
return Err(ProtocolError::BadMessage(
"query params not implemented".to_string(),
));
}
ensure!(nparams == 0, "query params not implemented");
Ok(FeMessage::Parse(FeParseMessage { query_string }))
}
}
impl FeDescribeMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let kind = buf.get_u8();
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
if kind != b'S' {
return Err(ProtocolError::BadMessage(
"only prepared statemement Describe is implemented".to_string(),
));
}
ensure!(
kind == b'S',
"only prepared statemement Describe is implemented"
);
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
}
}
impl FeExecuteMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_cstr(&mut buf)?;
if buf.remaining() < 4 {
return Err(ProtocolError::BadMessage(
"FeExecuteMessage message is malformed, maxrows missing".to_string(),
));
}
let maxrows = buf.get_i32();
if !portal_name.is_empty() {
return Err(ProtocolError::BadMessage(
"named portals not implemented".to_string(),
));
}
if maxrows != 0 {
return Err(ProtocolError::BadMessage(
"row limit in Execute message not implemented".to_string(),
));
}
ensure!(portal_name.is_empty(), "named portals not implemented");
ensure!(maxrows == 0, "row limit in Execute message not implemented");
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
}
}
impl FeBindMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_cstr(&mut buf)?;
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
if !portal_name.is_empty() {
return Err(ProtocolError::BadMessage(
"named portals not implemented".to_string(),
));
}
ensure!(portal_name.is_empty(), "named portals not implemented");
Ok(FeMessage::Bind(FeBindMessage))
}
}
impl FeCloseMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _kind = buf.get_u8();
let _pstmt_or_portal_name = read_cstr(&mut buf)?;
@@ -494,7 +529,6 @@ pub enum BeMessage<'a> {
CloseComplete,
// None means column is NULL
DataRow(&'a [Option<&'a [u8]>]),
// None errcode means internal_error will be sent.
ErrorResponse(&'a str, Option<&'a [u8; 5]>),
/// Single byte - used in response to SSLRequest/GSSENCRequest.
EncryptionResponse(bool),
@@ -525,11 +559,6 @@ impl<'a> BeMessage<'a> {
value: b"UTF8",
};
pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
name: b"integer_datetimes",
value: b"on",
};
/// Build a [`BeMessage::ParameterStatus`] holding the server version.
pub fn server_version(version: &'a str) -> Self {
Self::ParameterStatus {
@@ -608,7 +637,7 @@ impl RowDescriptor<'_> {
#[derive(Debug)]
pub struct XLogDataBody<'a> {
pub wal_start: u64,
pub wal_end: u64, // current end of WAL on the server
pub wal_end: u64,
pub timestamp: i64,
pub data: &'a [u8],
}
@@ -648,11 +677,12 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
}
/// Safe write of s into buf as cstring (String in the protocol).
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
let bytes = s.as_ref();
if bytes.contains(&0) {
return Err(ProtocolError::BadMessage(
"string contains embedded null".to_owned(),
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
}
buf.put_slice(bytes);
@@ -660,27 +690,22 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolErr
Ok(())
}
/// Read cstring from buf, advancing it.
fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
let pos = buf
.iter()
.position(|x| *x == 0)
.ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
let result = buf.split_to(pos);
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
let pos = buf.iter().position(|x| *x == 0);
let result = buf.split_to(pos.context("missing terminator")?);
buf.advance(1); // drop the null terminator
Ok(result)
}
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
impl<'a> BeMessage<'a> {
/// Serialize `message` to the given `buf`.
/// Apart from smart memory managemet, BytesMut is good here as msg len
/// precedes its body and it is handy to write it down first and then fill
/// the length. With Write we would have to either calc it manually or have
/// one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
/// Write message to the given buf.
// Unlike the reading side, we use BytesMut
// here as msg len precedes its body and it is handy to write it down first
// and then fill the length. With Write we would have to either calc it
// manually or have one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
match message {
BeMessage::AuthenticationOk => {
buf.put_u8(b'R');
@@ -725,7 +750,7 @@ impl<'a> BeMessage<'a> {
buf.put_slice(extra);
}
}
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -829,7 +854,7 @@ impl<'a> BeMessage<'a> {
write_cstr(error_msg, buf)?;
buf.put_u8(0); // terminator
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -852,7 +877,7 @@ impl<'a> BeMessage<'a> {
write_cstr(error_msg.as_bytes(), buf)?;
buf.put_u8(0); // terminator
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -907,7 +932,7 @@ impl<'a> BeMessage<'a> {
buf.put_i32(-1); /* typmod */
buf.put_i16(0); /* format code */
}
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -974,7 +999,7 @@ impl ReplicationFeedback {
// null-terminated string - key,
// uint32 - value length in bytes
// value itself
pub fn serialize(&self, buf: &mut BytesMut) {
pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> {
buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys
buf.put_slice(b"current_timeline_size\0");
buf.put_i32(8);
@@ -999,6 +1024,7 @@ impl ReplicationFeedback {
buf.put_slice(b"ps_replytime\0");
buf.put_i32(8);
buf.put_i64(timestamp);
Ok(())
}
// Deserialize ReplicationFeedback message
@@ -1066,7 +1092,7 @@ mod tests {
// because it is rounded up to microseconds during serialization.
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
let mut data = BytesMut::new();
rf.serialize(&mut data);
rf.serialize(&mut data).unwrap();
let rf_parsed = ReplicationFeedback::parse(data.freeze());
assert_eq!(rf, rf_parsed);
@@ -1081,7 +1107,7 @@ mod tests {
// because it is rounded up to microseconds during serialization.
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
let mut data = BytesMut::new();
rf.serialize(&mut data);
rf.serialize(&mut data).unwrap();
// Add an extra field to the buffer and adjust number of keys
if let Some(first) = data.first_mut() {
@@ -1123,6 +1149,15 @@ mod tests {
let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
}
// Make sure that `read` is sync/async callable
async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
let _ = FeMessage::read(&mut [].as_ref());
let _ = FeMessage::read_fut(stream).await;
let _ = FeStartupPacket::read(&mut [].as_ref());
let _ = FeStartupPacket::read_fut(stream).await;
}
}
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {

179
libs/pq_proto/src/sync.rs Normal file
View File

@@ -0,0 +1,179 @@
use pin_project_lite::pin_project;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::{io, task};
pin_project! {
/// We use this future to mark certain methods
/// as callable in both sync and async modes.
#[repr(transparent)]
pub struct SyncFuture<S, T: Future> {
#[pin]
inner: T,
_marker: PhantomData<S>,
}
}
/// This wrapper lets us synchronously wait for inner future's completion
/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**.
/// For instance, `S` may be substituted with types implementing
/// [`tokio::io::AsyncRead`], but it's not the only viable option.
impl<S, T: Future> SyncFuture<S, T> {
/// NOTE: caller should carefully pick a type for `S`,
/// because we don't want to enable [`SyncFuture::wait`] when
/// it's in fact impossible to run the future synchronously.
/// Violation of this contract will not cause UB, but
/// panics and async event loop freezes won't please you.
///
/// Example:
///
/// ```
/// # use pq_proto::sync::SyncFuture;
/// # use std::future::Future;
/// # use tokio::io::AsyncReadExt;
/// #
/// // Parse a pair of numbers from a stream
/// pub fn parse_pair<Reader>(
/// stream: &mut Reader,
/// ) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<(u32, u64)>> + '_>
/// where
/// Reader: tokio::io::AsyncRead + Unpin,
/// {
/// // If `Reader` is a `SyncProof`, this will give caller
/// // an opportunity to use `SyncFuture::wait`, because
/// // `.await` will always result in `Poll::Ready`.
/// SyncFuture::new(async move {
/// let x = stream.read_u32().await?;
/// let y = stream.read_u64().await?;
/// Ok((x, y))
/// })
/// }
/// ```
pub fn new(inner: T) -> Self {
Self {
inner,
_marker: PhantomData,
}
}
}
impl<S, T: Future> Future for SyncFuture<S, T> {
type Output = T::Output;
/// In async code, [`SyncFuture`] behaves like a regular wrapper.
#[inline(always)]
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
self.project().inner.poll(cx)
}
}
/// Postulates that we can call [`SyncFuture::wait`].
/// If implementer is also a [`Future`], it should always
/// return [`task::Poll::Ready`] from [`Future::poll`].
///
/// Each implementation should document which futures
/// specifically are being declared sync-proof.
pub trait SyncPostulate {}
impl<T: SyncPostulate> SyncPostulate for &T {}
impl<T: SyncPostulate> SyncPostulate for &mut T {}
impl<P: SyncPostulate, T: Future> SyncFuture<P, T> {
/// Synchronously wait for future completion.
pub fn wait(mut self) -> T::Output {
const RAW_WAKER: task::RawWaker = task::RawWaker::new(
std::ptr::null(),
&task::RawWakerVTable::new(
|_| RAW_WAKER,
|_| panic!("SyncFuture: failed to wake"),
|_| panic!("SyncFuture: failed to wake by ref"),
|_| { /* drop is no-op */ },
),
);
// SAFETY: We never move `self` during this call;
// furthermore, it will be dropped in the end regardless of panics
let this = unsafe { Pin::new_unchecked(&mut self) };
// SAFETY: This waker doesn't do anything apart from panicking
let waker = unsafe { task::Waker::from_raw(RAW_WAKER) };
let context = &mut task::Context::from_waker(&waker);
match this.poll(context) {
task::Poll::Ready(res) => res,
_ => panic!("SyncFuture: unexpected pending!"),
}
}
}
/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`],
/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`].
/// NOTE: you **should not** use this in async code.
#[repr(transparent)]
pub struct AsyncishRead<T: io::Read + Unpin>(pub T);
/// This lets us call [`SyncFuture<AsyncishRead<_>, _>::wait`],
/// and allows the future to await on any of the [`AsyncRead`]
/// and [`AsyncReadExt`] methods on `AsyncishRead`.
impl<T: io::Read + Unpin> SyncPostulate for AsyncishRead<T> {}
impl<T: io::Read + Unpin> tokio::io::AsyncRead for AsyncishRead<T> {
#[inline(always)]
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
task::Poll::Ready(
// `Read::read` will block, meaning we don't need a real event loop!
self.0
.read(buf.initialize_unfilled())
.map(|sz| buf.advance(sz)),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// async helper(stream: &mut impl AsyncRead) -> io::Result<u32>
fn bytes_add<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = io::Result<u32>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
SyncFuture::new(async move {
let a = stream.read_u32().await?;
let b = stream.read_u32().await?;
Ok(a + b)
})
}
#[test]
fn test_sync() {
let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat();
let res = bytes_add(&mut AsyncishRead(&mut &bytes[..]))
.wait()
.unwrap();
assert_eq!(res, 300);
}
// We need a single-threaded executor for this test
#[tokio::test(flavor = "current_thread")]
async fn test_async() {
let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap();
let write = async move {
tx.write_u32(100).await?;
tx.write_u32(200).await?;
Ok(())
};
let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap();
assert_eq!(res, 300);
}
}

View File

@@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
}
pub struct Download {
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send + Sync>>,
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send>>,
/// Extra key-value data, associated with the current remote file.
pub metadata: Option<StorageMetadata>,
}

View File

@@ -12,37 +12,42 @@ anyhow.workspace = true
bincode.workspace = true
bytes.workspace = true
heapless.workspace = true
hex = { workspace = true, features = ["serde"] }
hyper = { workspace = true, features = ["full"] }
futures = { workspace = true}
jsonwebtoken.workspace = true
nix.workspace = true
once_cell.workspace = true
pin-project-lite.workspace = true
routerify.workspace = true
serde.workspace = true
serde_json.workspace = true
signal-hook.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-rustls.workspace = true
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["json"] }
nix.workspace = true
signal-hook.workspace = true
rand.workspace = true
jsonwebtoken.workspace = true
hex = { workspace = true, features = ["serde"] }
rustls.workspace = true
rustls-split.workspace = true
git-version.workspace = true
serde_with.workspace = true
once_cell.workspace = true
strum.workspace = true
strum_macros.workspace = true
url.workspace = true
uuid = { version = "1.2", features = ["v4", "serde"] }
metrics.workspace = true
workspace_hack.workspace = true
pq_proto.workspace = true
workspace_hack.workspace = true
url.workspace = true
uuid = { version = "1.2", features = ["v4", "serde"] }
[dev-dependencies]
byteorder.workspace = true
bytes.workspace = true
criterion.workspace = true
hex-literal.workspace = true
tempfile.workspace = true
criterion.workspace = true
rustls-pemfile.workspace = true
[[bench]]
name = "benchmarks"

View File

@@ -1,4 +1,7 @@
// For details about authentication see docs/authentication.md
//
// TODO: use ed25519 keys
// Relevant issue: https://github.com/Keats/jsonwebtoken/issues/162
use serde;
use std::fs;
@@ -13,10 +16,9 @@ use serde_with::{serde_as, DisplayFromStr};
use crate::id::TenantId;
/// Algorithm to use. We require EdDSA.
const STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA;
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Scope {
// Provides access to all data for a specific tenant (specified in `struct Claims` below)
@@ -31,9 +33,8 @@ pub enum Scope {
SafekeeperData,
}
/// JWT payload. See docs/authentication.md for the format
#[serde_as]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
#[serde(default)]
#[serde_as(as = "Option<DisplayFromStr>")]
@@ -54,8 +55,7 @@ pub struct JwtAuth {
impl JwtAuth {
pub fn new(decoding_key: DecodingKey) -> Self {
let mut validation = Validation::default();
validation.algorithms = vec![STORAGE_TOKEN_ALGORITHM];
let mut validation = Validation::new(JWT_ALGORITHM);
// The default 'required_spec_claims' is 'exp'. But we don't want to require
// expiration.
validation.required_spec_claims = [].into();
@@ -67,7 +67,7 @@ impl JwtAuth {
pub fn from_key_path(key_path: &Path) -> Result<Self> {
let public_key = fs::read(key_path)?;
Ok(Self::new(DecodingKey::from_ed_pem(&public_key)?))
Ok(Self::new(DecodingKey::from_rsa_pem(&public_key)?))
}
pub fn decode(&self, token: &str) -> Result<TokenData<Claims>> {
@@ -85,75 +85,6 @@ impl std::fmt::Debug for JwtAuth {
// this function is used only for testing purposes in CLI e g generate tokens during init
pub fn encode_from_key_file(claims: &Claims, key_data: &[u8]) -> Result<String> {
let key = EncodingKey::from_ed_pem(key_data)?;
Ok(encode(&Header::new(STORAGE_TOKEN_ALGORITHM), claims, &key)?)
}
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
// Generated with:
//
// openssl genpkey -algorithm ed25519 -out ed25519-priv.pem
// openssl pkey -in ed25519-priv.pem -pubout -out ed25519-pub.pem
const TEST_PUB_KEY_ED25519: &[u8] = br#"
-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEARYwaNBayR+eGI0iXB4s3QxE3Nl2g1iWbr6KtLWeVD/w=
-----END PUBLIC KEY-----
"#;
const TEST_PRIV_KEY_ED25519: &[u8] = br#"
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
-----END PRIVATE KEY-----
"#;
#[test]
fn test_decode() -> Result<(), anyhow::Error> {
let expected_claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
scope: Scope::Tenant,
};
// A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
//
// ```
// {
// "scope": "tenant",
// "tenant_id": "3d1f7595b468230304e0b73cecbcb081",
// "iss": "neon.controlplane",
// "exp": 1709200879,
// "iat": 1678442479
// }
// ```
//
let encoded_eddsa = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.U3eA8j-uU-JnhzeO3EDHRuXLwkAUFCPxtGHEgw6p7Ccc3YRbFs2tmCdbD9PZEXP-XsxSeBQi1FY0YPcT3NXADw";
// Check it can be validated with the public key
let auth = JwtAuth::new(DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519)?);
let claims_from_token = auth.decode(encoded_eddsa)?.claims;
assert_eq!(claims_from_token, expected_claims);
Ok(())
}
#[test]
fn test_encode() -> Result<(), anyhow::Error> {
let claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
scope: Scope::Tenant,
};
let encoded = encode_from_key_file(&claims, TEST_PRIV_KEY_ED25519)?;
// decode it back
let auth = JwtAuth::new(DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519)?);
let decoded = auth.decode(&encoded)?;
assert_eq!(decoded.claims, claims);
Ok(())
}
let key = EncodingKey::from_rsa_pem(key_data)?;
Ok(encode(&Header::new(JWT_ALGORITHM), claims, &key)?)
}

View File

@@ -11,7 +11,7 @@ where
P: AsRef<Path>,
{
fn is_empty_dir(&self) -> io::Result<bool> {
Ok(fs::read_dir(self)?.next().is_none())
Ok(fs::read_dir(self)?.into_iter().next().is_none())
}
}

View File

@@ -3,14 +3,14 @@ use crate::http::error;
use anyhow::{anyhow, Context};
use hyper::header::{HeaderName, AUTHORIZATION};
use hyper::http::HeaderValue;
use hyper::Method;
use hyper::{header::CONTENT_TYPE, Body, Request, Response, Server};
use hyper::{Method, StatusCode};
use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
use once_cell::sync::Lazy;
use routerify::ext::RequestExt;
use routerify::{Middleware, RequestInfo, Router, RouterBuilder, RouterService};
use tokio::task::JoinError;
use tracing::{self, debug, info, info_span, warn, Instrument};
use tracing;
use std::future::Future;
use std::net::TcpListener;
@@ -32,77 +32,31 @@ static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HE
#[derive(Debug, Default, Clone)]
struct RequestId(String);
/// Adds a tracing info_span! instrumentation around the handler events,
/// logs the request start and end events for non-GET requests and non-200 responses.
///
/// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
/// in this type will get request info logged in the wrapping span, including the unique request ID.
///
/// There could be other ways to implement similar functionality:
///
/// * procmacros placed on top of all handler methods
/// With all the drawbacks of procmacros, brings no difference implementation-wise,
/// and little code reduction compared to the existing approach.
///
/// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
/// implemented for [`RouterBuilder`].
/// Could be simpler, but we don't want to depend on [`routerify`] more, targeting to use other library later.
///
/// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
/// later, in a post-response middleware.
/// Due to suspendable nature of the futures, would give contradictive results which is exactly the opposite of what `tracing-futures`
/// tries to achive with its `.instrument` used in the current approach.
///
/// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
pub struct RequestSpan<E, R, H>(pub H)
where
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
R: Future<Output = Result<Response<Body>, E>> + Send + 'static,
H: Fn(Request<Body>) -> R + Send + Sync + 'static;
async fn logger(res: Response<Body>, info: RequestInfo) -> Result<Response<Body>, ApiError> {
let request_id = info.context::<RequestId>().unwrap_or_default().0;
impl<E, R, H> RequestSpan<E, R, H>
where
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
R: Future<Output = Result<Response<Body>, E>> + Send + 'static,
H: Fn(Request<Body>) -> R + Send + Sync + 'static,
{
/// Creates a tracing span around inner request handler and executes the request handler in the contex of that span.
/// Use as `|r| RequestSpan(my_handler).handle(r)` instead of `my_handler` as the request handler to get the span enabled.
pub async fn handle(self, request: Request<Body>) -> Result<Response<Body>, E> {
let request_id = request.context::<RequestId>().unwrap_or_default().0;
let method = request.method();
let path = request.uri().path();
let request_span = info_span!("request", %method, %path, %request_id);
// cannot factor out the Level to avoid the repetition
// because tracing can only work with const Level
// which is not the case here
let log_quietly = method == Method::GET;
async move {
if log_quietly {
debug!("Handling request");
} else {
info!("Handling request");
}
// Note that we reuse `error::handler` here and not returning and error at all,
// yet cannot use `!` directly in the method signature due to `routerify::RouterBuilder` limitation.
// Usage of the error handler also means that we expect only the `ApiError` errors to be raised in this call.
//
// Panics are not handled separately, there's a `tracing_panic_hook` from another module to do that globally.
match (self.0)(request).await {
Ok(response) => {
let response_status = response.status();
if log_quietly && response_status.is_success() {
debug!("Request handled, status: {response_status}");
} else {
info!("Request handled, status: {response_status}");
}
Ok(response)
}
Err(e) => Ok(error::handler(e.into()).await),
}
}
.instrument(request_span)
.await
if info.method() == Method::GET && res.status() == StatusCode::OK {
tracing::debug!(
"{} {} {} {}",
info.method(),
info.uri().path(),
request_id,
res.status()
);
} else {
tracing::info!(
"{} {} {} {}",
info.method(),
info.uri().path(),
request_id,
res.status()
);
}
Ok(res)
}
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
@@ -142,6 +96,12 @@ pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'stati
request_id.to_string()
}
};
if req.method() == Method::GET {
tracing::debug!("{} {} {}", req.method(), req.uri().path(), request_id);
} else {
tracing::info!("{} {} {}", req.method(), req.uri().path(), request_id);
}
req.set_context(RequestId(request_id));
Ok(req)
@@ -165,12 +125,11 @@ async fn add_request_id_header_to_response(
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
Router::builder()
.middleware(add_request_id_middleware())
.middleware(Middleware::post_with_info(logger))
.middleware(Middleware::post_with_info(
add_request_id_header_to_response,
))
.get("/metrics", |r| {
RequestSpan(prometheus_metrics_handler).handle(r)
})
.get("/metrics", prometheus_metrics_handler)
.err_handler(error::handler)
}
@@ -180,43 +139,40 @@ pub fn attach_openapi_ui(
spec_mount_path: &'static str,
ui_mount_path: &'static str,
) -> RouterBuilder<hyper::Body, ApiError> {
router_builder
.get(spec_mount_path, move |r| {
RequestSpan(move |_| async move { Ok(Response::builder().body(Body::from(spec)).unwrap()) })
.handle(r)
})
.get(ui_mount_path, move |r| RequestSpan( move |_| async move {
Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
window.onload = function() {{
const ui = SwaggerUIBundle({{
"dom_id": "\#swagger-ui",
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
url: "{}",
}})
window.ui = ui;
}};
</script>
</body>
</html>
"#, spec_mount_path))).unwrap())
}).handle(r))
router_builder.get(spec_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(spec)).unwrap())
}).get(ui_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
window.onload = function() {{
const ui = SwaggerUIBundle({{
"dom_id": "\#swagger-ui",
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
url: "{}",
}})
window.ui = ui;
}};
</script>
</body>
</html>
"#, spec_mount_path))).unwrap())
})
}
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
@@ -278,7 +234,7 @@ where
async move {
let headers = response.headers_mut();
if headers.contains_key(&name) {
warn!(
tracing::warn!(
"{} response already contains header {:?}",
request_info.uri(),
&name,
@@ -318,7 +274,7 @@ pub fn serve_thread_main<S>(
where
S: Future<Output = ()> + Send + Sync,
{
info!("Starting an HTTP endpoint at {}", listener.local_addr()?);
tracing::info!("Starting an HTTP endpoint at {}", listener.local_addr()?);
// Create a Service from the router above to handle incoming requests.
let service = RouterService::new(router_builder.build().map_err(|err| anyhow!(err))?).unwrap();

View File

@@ -13,6 +13,8 @@ pub mod simple_rcu;
pub mod vec_map;
pub mod bin_ser;
pub mod postgres_backend;
pub mod postgres_backend_async;
// helper functions for creating and fsyncing
pub mod crashsafe;
@@ -25,6 +27,9 @@ pub mod id;
// http endpoint utils
pub mod http;
// socket splitting utils
pub mod sock_split;
// common log initialisation routine
pub mod logging;
@@ -49,8 +54,6 @@ pub mod fs_ext;
pub mod history_buffer;
pub mod measured_stream;
/// use with fail::cfg("$name", "return(2000)")
#[macro_export]
macro_rules! failpoint_sleep_millis_async {

View File

@@ -1,77 +0,0 @@
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::{io, task};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pin_project! {
/// This stream tracks all writes and calls user provided
/// callback when the underlying stream is flushed.
pub struct MeasuredStream<S, R, W> {
#[pin]
stream: S,
write_count: usize,
inc_read_count: R,
inc_write_count: W,
}
}
impl<S, R, W> MeasuredStream<S, R, W> {
pub fn new(stream: S, inc_read_count: R, inc_write_count: W) -> Self {
Self {
stream,
write_count: 0,
inc_read_count,
inc_write_count,
}
}
}
impl<S: AsyncRead + Unpin, R: FnMut(usize), W> AsyncRead for MeasuredStream<S, R, W> {
fn poll_read(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.project();
let filled = buf.filled().len();
this.stream.poll_read(context, buf).map_ok(|()| {
let cnt = buf.filled().len() - filled;
// Increment the read count.
(this.inc_read_count)(cnt);
})
}
}
impl<S: AsyncWrite + Unpin, R, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, R, W> {
fn poll_write(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
let this = self.project();
this.stream.poll_write(context, buf).map_ok(|cnt| {
// Increment the write count.
*this.write_count += cnt;
cnt
})
}
fn poll_flush(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.project();
this.stream.poll_flush(context).map_ok(|()| {
// Call the user provided callback and reset the write count.
(this.inc_write_count)(*this.write_count);
*this.write_count = 0;
})
}
fn poll_shutdown(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
self.project().stream.poll_shutdown(context)
}
}

View File

@@ -0,0 +1,485 @@
//! Server-side synchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::postgres_backend_async::{log_query_error, short_error, QueryError};
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
use anyhow::Context;
use bytes::{Bytes, BytesMut};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::io::{self, Write};
use std::net::{Shutdown, SocketAddr, TcpStream};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tracing::*;
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
fn is_shutdown_requested(&self) -> bool {
false
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
Initialization,
Encrypted,
Authentication,
Established,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
}
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
})
}
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Bidirectional(BidiStream),
WriteOnly(WriteStream),
}
impl Stream {
fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how),
Self::WriteOnly(write_stream) => write_stream.shutdown(how),
}
}
}
impl io::Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.write(buf),
Self::WriteOnly(write_stream) => write_stream.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.flush(),
Self::WriteOnly(write_stream) => write_stream.flush(),
}
}
}
pub struct PostgresBackend {
stream: Option<Stream>,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
buf_out: BytesMut,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
// Helper function for socket read loops
pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool {
for cause in error.chain() {
if let Some(io_error) = cause.downcast_ref::<io::Error>() {
if io_error.kind() == std::io::ErrorKind::WouldBlock {
return true;
}
}
}
false
}
// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
set_read_timeout: bool,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
if set_read_timeout {
socket
.set_read_timeout(Some(Duration::from_secs(5)))
.unwrap();
}
Ok(Self {
stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))),
buf_out: BytesMut::with_capacity(10 * 1024),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn into_stream(self) -> Stream {
self.stream.unwrap()
}
/// Get direct reference (into the Option) to the read stream.
fn get_stream_in(&mut self) -> anyhow::Result<&mut BidiStream> {
match &mut self.stream {
Some(Stream::Bidirectional(stream)) => Ok(stream),
_ => anyhow::bail!("reader taken"),
}
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
pub fn take_stream_in(&mut self) -> Option<ReadStream> {
let stream = self.stream.take();
match stream {
Some(Stream::Bidirectional(bidi_stream)) => {
let (read, write) = bidi_stream.split();
self.stream = Some(Stream::WriteOnly(write));
Some(read)
}
stream => {
self.stream = stream;
None
}
}
}
/// Read full message or return None if connection is closed.
pub fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
let (state, stream) = (self.state, self.get_stream_in()?);
use ProtoState::*;
match state {
Initialization | Encrypted => FeStartupPacket::read(stream),
Authentication | Established => FeMessage::read(stream),
}
.map_err(QueryError::from)
}
/// Write message into internal output buffer.
pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Flush output buffer into the socket.
pub fn flush(&mut self) -> io::Result<&mut Self> {
let stream = self.stream.as_mut().unwrap();
stream.write_all(&self.buf_out)?;
self.buf_out.clear();
Ok(self)
}
/// Write message into internal buffer and flush it.
pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush()
}
// Wrapper for run_message_loop() that shuts down socket when we are done
pub fn run(mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
let ret = self.run_message_loop(handler);
if let Some(stream) = self.stream.as_mut() {
let _ = stream.shutdown(Shutdown::Both);
}
ret
}
fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
trace!("postgres backend to {:?} started", self.peer_addr);
let mut unnamed_query_string = Bytes::new();
while !handler.is_shutdown_requested() {
match self.read_message() {
Ok(message) => {
if let Some(msg) = message {
trace!("got message {msg:?}");
match self.process_message(handler, msg, &mut unnamed_query_string)? {
ProcessMsgResult::Continue => continue,
ProcessMsgResult::Break => break,
}
} else {
break;
}
}
Err(e) => {
if let QueryError::Other(e) = &e {
if is_socket_read_timed_out(e) {
continue;
}
}
return Err(e);
}
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
pub fn start_tls(&mut self) -> anyhow::Result<()> {
match self.stream.take() {
Some(Stream::Bidirectional(bidi_stream)) => {
let conn = rustls::ServerConnection::new(self.tls_config.clone().unwrap())?;
self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(conn)?));
Ok(())
}
stream => {
self.stream = stream;
anyhow::bail!("can't start TLs without bidi stream");
}
}
}
fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
if self.state < ProtoState::Established
&& !matches!(
msg,
FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_)
)
{
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
}
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {m:?}");
match m {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls()?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse(
"must connect with TLS",
None,
))?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
return Ok(ProcessMsgResult::Break);
}
}
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
match self.auth_type {
AuthType::Trust => unreachable!(),
AuthType::NeonJWT => {
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
}
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string) {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string) {
log_query_error(query_string, &e);
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {msg:?}"
)));
}
}
Ok(ProcessMsgResult::Continue)
}
}

View File

@@ -0,0 +1,634 @@
//! Server-side asynchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::postgres_backend::AuthType;
use anyhow::Context;
use bytes::{Buf, Bytes, BytesMut};
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use std::{future::Future, task::ready};
use tracing::{debug, error, info, trace};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio_rustls::TlsAcceptor;
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
/// An error, occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
/// The connection was lost while processing the query.
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl From<io::Error> for QueryError {
fn from(e: io::Error) -> Self {
Self::Disconnected(ConnectionError::Socket(e))
}
}
impl QueryError {
pub fn pg_error_code(&self) -> &'static [u8; 5] {
match self {
Self::Disconnected(_) => b"08006", // connection failure
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
}
}
}
#[async_trait::async_trait]
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
Initialization,
Encrypted,
Authentication,
Established,
Closed,
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Unencrypted(BufReader<tokio::net::TcpStream>),
Tls(Box<tokio_rustls::server::TlsStream<BufReader<tokio::net::TcpStream>>>),
Broken,
}
impl AsyncWrite for Stream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Broken => unreachable!(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
Self::Broken => unreachable!(),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Broken => unreachable!(),
}
}
}
impl AsyncRead for Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Broken => unreachable!(),
}
}
}
pub struct PostgresBackend {
stream: Stream,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
// The data between 0 and "current position" as tracked by the bytes::Buf
// implementation of BytesMut, have already been written.
buf_out: BytesMut,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
Ok(Self {
stream: Stream::Unencrypted(BufReader::new(socket)),
buf_out: BytesMut::with_capacity(10 * 1024),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
/// Read full message or return None if connection is closed.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
use ProtoState::*;
match self.state {
Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await,
Authentication | Established => FeMessage::read_fut(&mut self.stream).await,
Closed => Ok(None),
}
.map_err(QueryError::from)
}
/// Flush output buffer into the socket.
pub async fn flush(&mut self) -> io::Result<()> {
while self.buf_out.has_remaining() {
let bytes_written = self.stream.write(self.buf_out.chunk()).await?;
self.buf_out.advance(bytes_written);
}
self.buf_out.clear();
Ok(())
}
/// Write message into internal output buffer.
pub fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Returns an AsyncWrite implementation that wraps all the data written
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter {
CopyDataWriter { pgb: self }
}
/// A polling function that tries to write all the data from 'buf_out' to the
/// underlying stream.
fn poll_write_buf(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
while self.buf_out.has_remaining() {
match ready!(Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk())) {
Ok(bytes_written) => self.buf_out.advance(bytes_written),
Err(err) => return Poll::Ready(Err(err)),
}
}
Poll::Ready(Ok(()))
}
fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
let ret = self.run_message_loop(handler, shutdown_watcher).await;
let _ = self.stream.shutdown();
ret
}
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
trace!("postgres backend to {:?} started", self.peer_addr);
tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received during handshake");
return Ok(())
},
result = async {
while self.state < ProtoState::Established {
if let Some(msg) = self.read_message().await? {
trace!("got message {msg:?} during handshake");
match self.process_handshake_message(handler, msg).await? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => {
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
return Ok(());
}
}
} else {
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
return Ok(());
}
}
Ok::<(), QueryError>(())
} => {
// Handshake complete.
result?;
}
);
// Authentication completed
let mut query_string = Bytes::new();
while let Some(msg) = tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received in run_message_loop");
Ok(None)
},
msg = self.read_message() => { msg },
)? {
trace!("got message {:?}", msg);
let result = self.process_message(handler, msg, &mut query_string).await;
self.flush().await?;
match result? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => break,
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
if let Stream::Unencrypted(plain_stream) =
std::mem::replace(&mut self.stream, Stream::Broken)
{
let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap());
let tls_stream = acceptor.accept(plain_stream).await?;
self.stream = Stream::Tls(Box::new(tls_stream));
return Ok(());
};
anyhow::bail!("TLS already started");
}
async fn process_handshake_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
) -> Result<ProcessMsgResult, QueryError> {
assert!(self.state < ProtoState::Established);
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {m:?}");
match m {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse(
"must connect with TLS",
None,
))?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)?
// The async python driver requires a valid server_version
.write_message(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
self.state = ProtoState::Closed;
return Ok(ProcessMsgResult::Break);
}
}
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
match self.auth_type {
AuthType::Trust => unreachable!(),
AuthType::NeonJWT => {
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
}
}
self.write_message(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
_ => {
self.state = ProtoState::Closed;
return Ok(ProcessMsgResult::Break);
}
}
Ok(ProcessMsgResult::Continue)
}
async fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
assert!(self.state == ProtoState::Established);
match msg {
FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => {
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
}
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {:?}",
msg
)));
}
}
Ok(ProcessMsgResult::Continue)
}
}
///
/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
/// messages.
///
pub struct CopyDataWriter<'a> {
pgb: &'a mut PostgresBackend,
}
impl<'a> AsyncWrite for CopyDataWriter<'a> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
// It's not strictly required to flush between each message, but makes it easier
// to view in wireshark, and usually the messages that the callers write are
// decently-sized anyway.
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
// CopyData
// XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32.
this.pgb.write_message(&BeMessage::CopyData(buf))?;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
this.pgb.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
this.pgb.poll_flush(cx)
}
}
pub fn short_error(e: &QueryError) -> String {
match e {
QueryError::Disconnected(connection_error) => connection_error.to_string(),
QueryError::Other(e) => format!("{e:#}"),
}
}
pub(super) fn log_query_error(query: &str, e: &QueryError) {
match e {
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
if is_expected_io_error(io_error) {
info!("query handler for '{query}' failed with expected io error: {io_error}");
} else {
error!("query handler for '{query}' failed with io error: {io_error}");
}
}
QueryError::Disconnected(other_connection_error) => {
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
}
QueryError::Other(e) => {
error!("query handler for '{query}' failed: {e:?}");
}
}
}

View File

@@ -0,0 +1,206 @@
use std::{
io::{self, BufReader, Write},
net::{Shutdown, TcpStream},
sync::Arc,
};
use rustls::Connection;
/// Wrapper supporting reads of a shared TcpStream.
pub struct ArcTcpRead(Arc<TcpStream>);
impl io::Read for ArcTcpRead {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
(&*self.0).read(buf)
}
}
impl std::ops::Deref for ArcTcpRead {
type Target = TcpStream;
fn deref(&self) -> &Self::Target {
self.0.deref()
}
}
/// Wrapper around a TCP Stream supporting buffered reads.
pub struct BufStream(BufReader<ArcTcpRead>);
impl io::Read for BufStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl io::Write for BufStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.get_ref().write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.get_ref().flush()
}
}
impl BufStream {
/// Unwrap into the internal BufReader.
fn into_reader(self) -> BufReader<ArcTcpRead> {
self.0
}
/// Returns a reference to the underlying TcpStream.
fn get_ref(&self) -> &TcpStream {
&self.0.get_ref().0
}
}
pub enum ReadStream {
Tcp(BufReader<ArcTcpRead>),
Tls(rustls_split::ReadHalf),
}
impl io::Read for ReadStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(reader) => reader.read(buf),
Self::Tls(read_half) => read_half.read(buf),
}
}
}
impl ReadStream {
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(write_half) => write_half.shutdown(how),
}
}
}
pub enum WriteStream {
Tcp(Arc<TcpStream>),
Tls(rustls_split::WriteHalf),
}
impl WriteStream {
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.shutdown(how),
Self::Tls(write_half) => write_half.shutdown(how),
}
}
}
impl io::Write for WriteStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.as_ref().write(buf),
Self::Tls(write_half) => write_half.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.as_ref().flush(),
Self::Tls(write_half) => write_half.flush(),
}
}
}
type TlsStream<T> = rustls::StreamOwned<rustls::ServerConnection, T>;
pub enum BidiStream {
Tcp(BufStream),
/// This variant is boxed, because [`rustls::ServerConnection`] is quite larger than [`BufStream`].
Tls(Box<TlsStream<BufStream>>),
}
impl BidiStream {
pub fn from_tcp(stream: TcpStream) -> Self {
Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream)))))
}
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(tls_boxed) => {
if how == Shutdown::Read {
tls_boxed.sock.get_ref().shutdown(how)
} else {
tls_boxed.conn.send_close_notify();
let res = tls_boxed.flush();
tls_boxed.sock.get_ref().shutdown(how)?;
res
}
}
}
}
/// Split the bi-directional stream into two owned read and write halves.
pub fn split(self) -> (ReadStream, WriteStream) {
match self {
Self::Tcp(stream) => {
let reader = stream.into_reader();
let stream: Arc<TcpStream> = reader.get_ref().0.clone();
(ReadStream::Tcp(reader), WriteStream::Tcp(stream))
}
Self::Tls(tls_boxed) => {
let reader = tls_boxed.sock.into_reader();
let buffer_data = reader.buffer().to_owned();
let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192);
let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192);
// TODO would be nice to avoid the Arc here
let socket = Arc::try_unwrap(reader.into_inner().0).unwrap();
let (read_half, write_half) = rustls_split::split(
socket,
Connection::Server(tls_boxed.conn),
read_buf_cfg,
write_buf_cfg,
);
(ReadStream::Tls(read_half), WriteStream::Tls(write_half))
}
}
}
pub fn start_tls(self, mut conn: rustls::ServerConnection) -> io::Result<Self> {
match self {
Self::Tcp(mut stream) => {
conn.complete_io(&mut stream)?;
assert!(!conn.is_handshaking());
Ok(Self::Tls(Box::new(TlsStream::new(conn, stream))))
}
Self::Tls { .. } => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"TLS is already started on this stream",
)),
}
}
}
impl io::Read for BidiStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
Self::Tls(tls_boxed) => tls_boxed.read(buf),
}
}
}
impl io::Write for BidiStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.write(buf),
Self::Tls(tls_boxed) => tls_boxed.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.flush(),
Self::Tls(tls_boxed) => tls_boxed.flush(),
}
}
}

View File

@@ -0,0 +1,238 @@
use std::{
collections::HashMap,
io::{Cursor, Read, Write},
net::{TcpListener, TcpStream},
sync::Arc,
};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use once_cell::sync::Lazy;
use utils::{
postgres_backend::{AuthType, Handler, PostgresBackend},
postgres_backend_async::QueryError,
};
fn make_tcp_pair() -> (TcpStream, TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(addr).unwrap();
let (server_stream, _) = listener.accept().unwrap();
(server_stream, client_stream)
}
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("key.pem"));
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
});
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
});
#[test]
// [false-positive](https://github.com/rust-lang/rust-clippy/issues/9274),
// we resize the vector so doing some modifications after all
#[allow(clippy::read_zero_byte_vec)]
fn ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
const QUERY: &str = "hello world";
let client_jh = std::thread::spawn(move || {
// SSLRequest
client_sock.write_u32::<BigEndian>(8).unwrap();
client_sock.write_u32::<BigEndian>(80877103).unwrap();
let ssl_response = client_sock.read_u8().unwrap();
assert_eq!(b'S', ssl_response);
let cfg = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&CERT).unwrap();
store
})
.with_no_client_auth();
let client_config = Arc::new(cfg);
let dns_name = "localhost".try_into().unwrap();
let mut conn = rustls::ClientConnection::new(client_config, dns_name).unwrap();
conn.complete_io(&mut client_sock).unwrap();
assert!(!conn.is_handshaking());
let mut stream = rustls::Stream::new(&mut conn, &mut client_sock);
// StartupMessage
stream.write_u32::<BigEndian>(9).unwrap();
stream.write_u32::<BigEndian>(196608).unwrap();
stream.write_u8(0).unwrap();
stream.flush().unwrap();
// wait for ReadyForQuery
let mut msg_buf = Vec::new();
loop {
let msg = stream.read_u8().unwrap();
let size = stream.read_u32::<BigEndian>().unwrap() - 4;
msg_buf.resize(size as usize, 0);
stream.read_exact(&mut msg_buf).unwrap();
if msg == b'Z' {
// ReadyForQuery
break;
}
}
// Query
stream.write_u8(b'Q').unwrap();
stream
.write_u32::<BigEndian>(4u32 + QUERY.len() as u32)
.unwrap();
stream.write_all(QUERY.as_ref()).unwrap();
stream.flush().unwrap();
// ReadyForQuery
let msg = stream.read_u8().unwrap();
assert_eq!(msg, b'Z');
});
struct TestHandler {
got_query: bool,
}
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError> {
self.got_query = query_string == QUERY;
Ok(())
}
}
let mut handler = TestHandler { got_query: false };
let cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(cfg));
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
pgb.run(&mut handler).unwrap();
assert!(handler.got_query);
client_jh.join().unwrap();
// TODO consider shutdown behavior
}
#[test]
fn no_ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
let client_jh = std::thread::spawn(move || {
let mut buf = BytesMut::new();
// SSLRequest
buf.put_u32(8);
buf.put_u32(80877103);
client_sock.write_all(&buf).unwrap();
buf.clear();
let ssl_response = client_sock.read_u8().unwrap();
assert_eq!(b'N', ssl_response);
});
struct TestHandler;
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
panic!()
}
}
let mut handler = TestHandler;
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, None, true).unwrap();
pgb.run(&mut handler).unwrap();
client_jh.join().unwrap();
}
#[test]
fn server_forces_ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
let client_jh = std::thread::spawn(move || {
// StartupMessage
client_sock.write_u32::<BigEndian>(9).unwrap();
client_sock.write_u32::<BigEndian>(196608).unwrap();
client_sock.write_u8(0).unwrap();
client_sock.flush().unwrap();
// ErrorResponse
assert_eq!(client_sock.read_u8().unwrap(), b'E');
let len = client_sock.read_u32::<BigEndian>().unwrap() - 4;
let mut body = vec![0; len as usize];
client_sock.read_exact(&mut body).unwrap();
let mut body = Bytes::from(body);
let mut errors = HashMap::new();
loop {
let field_type = body.get_u8();
if field_type == 0u8 {
break;
}
let end_idx = body.iter().position(|&b| b == 0u8).unwrap();
let mut value = body.split_to(end_idx + 1);
assert_eq!(value[end_idx], 0u8);
value.truncate(end_idx);
let old = errors.insert(field_type, value);
assert!(old.is_none());
}
assert!(!body.has_remaining());
assert_eq!("must connect with TLS", errors.get(&b'M').unwrap());
// TODO read failure
});
struct TestHandler;
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
panic!()
}
}
let mut handler = TestHandler;
let cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(cfg));
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
let res = pgb.run(&mut handler).unwrap_err();
assert_eq!("client did not connect with TLS", format!("{}", res));
client_jh.join().unwrap();
// TODO consider shutdown behavior
}

4
libs/walproposer/.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
*.a
*.o
*.tmp
pgdata

View File

@@ -0,0 +1,39 @@
[package]
name = "walproposer"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
atty.workspace = true
rand.workspace = true
regex.workspace = true
bytes.workspace = true
byteorder.workspace = true
anyhow.workspace = true
crc32c.workspace = true
hex.workspace = true
once_cell.workspace = true
log.workspace = true
libc.workspace = true
memoffset.workspace = true
thiserror.workspace = true
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["json"] }
serde.workspace = true
scopeguard.workspace = true
utils.workspace = true
safekeeper.workspace = true
postgres_ffi.workspace = true
hyper.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
env_logger.workspace = true
postgres.workspace = true
[build-dependencies]
anyhow.workspace = true
bindgen.workspace = true
cbindgen = "0.24.0"

View File

@@ -0,0 +1,16 @@
# walproposer Rust module
## Rust -> C
We compile walproposer as a static library and generate Rust bindings for it using `bindgen`.
Entrypoint header file is `bindgen_deps.h`.
## C -> Rust
We use `cbindgen` to generate C bindings for the Rust code. They are stored in `rust_bindings.h`.
## How to run the tests
```
export RUSTFLAGS="-C default-linker-libraries"
```

View File

@@ -0,0 +1,30 @@
/*
* This header file is the input to bindgen. It includes all the
* PostgreSQL headers that we need to auto-generate Rust structs
* from. If you need to expose a new struct to Rust code, add the
* header here, and whitelist the struct in the build.rs file.
*/
#include "c.h"
#include "walproposer.h"
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
// Calc a sum of two numbers. Used to test Rust->C function calls.
int TestFunc(int a, int b);
// Run a client for simple simlib test.
void RunClientC(uint32_t serverId);
void WalProposerRust();
void WalProposerCleanup();
extern bool debug_enabled;
// Initialize global variables before calling any Postgres C code.
void MyContextInit();
XLogRecPtr MyInsertRecord();

137
libs/walproposer/build.rs Normal file
View File

@@ -0,0 +1,137 @@
use std::{env, path::PathBuf, process::Command};
use anyhow::{anyhow, Context};
use bindgen::CargoCallbacks;
extern crate bindgen;
fn main() -> anyhow::Result<()> {
let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
cbindgen::Builder::new()
.with_crate(crate_dir)
.with_language(cbindgen::Language::C)
.generate()
.expect("Unable to generate bindings")
.write_to_file("rust_bindings.h");
// Tell cargo to invalidate the built crate whenever the wrapper changes
println!("cargo:rerun-if-changed=bindgen_deps.h,test.c,../../pgxn/neon/walproposer.c,build.sh");
println!("cargo:rustc-link-arg=-Wl,--start-group");
println!("cargo:rustc-link-arg=-lsim");
println!("cargo:rustc-link-arg=-lpgport_srv");
println!("cargo:rustc-link-arg=-lpostgres");
println!("cargo:rustc-link-arg=-lpgcommon_srv");
println!("cargo:rustc-link-arg=-lssl");
println!("cargo:rustc-link-arg=-lcrypto");
println!("cargo:rustc-link-arg=-lz");
println!("cargo:rustc-link-arg=-lpthread");
println!("cargo:rustc-link-arg=-lrt");
println!("cargo:rustc-link-arg=-ldl");
println!("cargo:rustc-link-arg=-lm");
println!("cargo:rustc-link-arg=-lwalproposer");
println!("cargo:rustc-link-arg=-Wl,--end-group");
println!("cargo:rustc-link-search=/home/admin/simulator/libs/walproposer");
// disable fPIE
println!("cargo:rustc-link-arg=-no-pie");
// print output of build.sh
let output = std::process::Command::new("./build.sh")
.output()
.expect("could not spawn `clang`");
println!("stdout: {}", String::from_utf8(output.stdout).unwrap());
println!("stderr: {}", String::from_utf8(output.stderr).unwrap());
if !output.status.success() {
// Panic if the command was not successful.
panic!("could not compile object file");
}
// // Finding the location of C headers for the Postgres server:
// // - if POSTGRES_INSTALL_DIR is set look into it, otherwise look into `<project_root>/pg_install`
// // - if there's a `bin/pg_config` file use it for getting include server, otherwise use `<project_root>/pg_install/{PG_MAJORVERSION}/include/postgresql/server`
let pg_install_dir = if let Some(postgres_install_dir) = env::var_os("POSTGRES_INSTALL_DIR") {
postgres_install_dir.into()
} else {
PathBuf::from("pg_install")
};
let pg_version = "v15";
let mut pg_install_dir_versioned = pg_install_dir.join(pg_version);
if pg_install_dir_versioned.is_relative() {
let cwd = env::current_dir().context("Failed to get current_dir")?;
pg_install_dir_versioned = cwd.join("..").join("..").join(pg_install_dir_versioned);
}
let pg_config_bin = pg_install_dir_versioned
.join(pg_version)
.join("bin")
.join("pg_config");
let inc_server_path: String = if pg_config_bin.exists() {
let output = Command::new(pg_config_bin)
.arg("--includedir-server")
.output()
.context("failed to execute `pg_config --includedir-server`")?;
if !output.status.success() {
panic!("`pg_config --includedir-server` failed")
}
String::from_utf8(output.stdout)
.context("pg_config output is not UTF-8")?
.trim_end()
.into()
} else {
let server_path = pg_install_dir_versioned
.join("include")
.join("postgresql")
.join("server")
.into_os_string();
server_path
.into_string()
.map_err(|s| anyhow!("Bad postgres server path {s:?}"))?
};
let inc_pgxn_path = "/home/admin/simulator/pgxn/neon";
// The bindgen::Builder is the main entry point
// to bindgen, and lets you build up options for
// the resulting bindings.
let bindings = bindgen::Builder::default()
// The input header we would like to generate
// bindings for.
.header("bindgen_deps.h")
// Tell cargo to invalidate the built crate whenever any of the
// included header files changed.
.parse_callbacks(Box::new(CargoCallbacks))
.allowlist_function("TestFunc")
.allowlist_function("RunClientC")
.allowlist_function("WalProposerRust")
.allowlist_function("MyContextInit")
.allowlist_function("WalProposerCleanup")
.allowlist_function("MyInsertRecord")
.allowlist_var("wal_acceptors_list")
.allowlist_var("wal_acceptor_reconnect_timeout")
.allowlist_var("wal_acceptor_connection_timeout")
.allowlist_var("am_wal_proposer")
.allowlist_var("neon_timeline_walproposer")
.allowlist_var("neon_tenant_walproposer")
.allowlist_var("syncSafekeepers")
.allowlist_var("sim_redo_start_lsn")
.allowlist_var("debug_enabled")
.clang_arg(format!("-I{inc_server_path}"))
.clang_arg(format!("-I{inc_pgxn_path}"))
.clang_arg(format!("-DSIMLIB"))
// Finish the builder and generate the bindings.
.generate()
// Unwrap the Result and panic on failure.
.expect("Unable to generate bindings");
// Write the bindings to the $OUT_DIR/bindings.rs file.
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("bindings.rs");
bindings
.write_to_file(out_path)
.expect("Couldn't write bindings!");
Ok(())
}

21
libs/walproposer/build.sh Executable file
View File

@@ -0,0 +1,21 @@
#!/bin/bash
set -e
cd /home/admin/simulator/libs/walproposer
# TODO: rewrite to Makefile
make -C ../.. neon-pg-ext-walproposer
make -C ../../pg_install/build/v15/src/backend postgres-lib -s
cp ../../pg_install/build/v15/src/backend/libpostgres.a .
cp ../../pg_install/build/v15/src/common/libpgcommon_srv.a .
cp ../../pg_install/build/v15/src/port/libpgport_srv.a .
clang -g -c libpqwalproposer.c test.c -ferror-limit=1 -I ../../pg_install/v15/include/postgresql/server -I ../../pgxn/neon
rm -rf libsim.a
ar rcs libsim.a test.o libpqwalproposer.o
rm -rf libwalproposer.a
PGXN_DIR=../../pg_install/build/neon-v15/
ar rcs libwalproposer.a $PGXN_DIR/walproposer.o $PGXN_DIR/walproposer_utils.o $PGXN_DIR/neon.o

View File

@@ -0,0 +1,542 @@
#include "postgres.h"
#include "neon.h"
#include "walproposer.h"
#include "rust_bindings.h"
#include "replication/message.h"
#include "access/xlog_internal.h"
// defined in walproposer.h
uint64 sim_redo_start_lsn;
XLogRecPtr sim_latest_available_lsn;
/* Header in walproposer.h -- Wrapper struct to abstract away the libpq connection */
struct WalProposerConn
{
int64_t tcp;
};
/* Helper function */
static bool
ensure_nonblocking_status(WalProposerConn *conn, bool is_nonblocking)
{
// walprop_log(LOG, "not implemented");
return false;
}
/* Exported function definitions */
char *
walprop_error_message(WalProposerConn *conn)
{
// walprop_log(LOG, "not implemented");
return NULL;
}
WalProposerConnStatusType
walprop_status(WalProposerConn *conn)
{
// walprop_log(LOG, "not implemented: walprop_status");
return WP_CONNECTION_OK;
}
WalProposerConn *
walprop_connect_start(char *conninfo)
{
WalProposerConn *conn;
walprop_log(LOG, "walprop_connect_start: %s", conninfo);
const char *connstr_prefix = "host=node port=";
Assert(strncmp(conninfo, connstr_prefix, strlen(connstr_prefix)) == 0);
int nodeId = atoi(conninfo + strlen(connstr_prefix));
conn = palloc(sizeof(WalProposerConn));
conn->tcp = sim_open_tcp(nodeId);
return conn;
}
WalProposerConnectPollStatusType
walprop_connect_poll(WalProposerConn *conn)
{
// walprop_log(LOG, "not implemented: walprop_connect_poll");
return WP_CONN_POLLING_OK;
}
bool
walprop_send_query(WalProposerConn *conn, char *query)
{
// walprop_log(LOG, "not implemented: walprop_send_query");
return true;
}
WalProposerExecStatusType
walprop_get_query_result(WalProposerConn *conn)
{
// walprop_log(LOG, "not implemented: walprop_get_query_result");
return WP_EXEC_SUCCESS_COPYBOTH;
}
pgsocket
walprop_socket(WalProposerConn *conn)
{
return (pgsocket) conn->tcp;
}
int
walprop_flush(WalProposerConn *conn)
{
// walprop_log(LOG, "not implemented");
return 0;
}
void
walprop_finish(WalProposerConn *conn)
{
// walprop_log(LOG, "walprop_finish not implemented");
}
/*
* Receive a message from the safekeeper.
*
* On success, the data is placed in *buf. It is valid until the next call
* to this function.
*/
PGAsyncReadResult
walprop_async_read(WalProposerConn *conn, char **buf, int *amount)
{
uintptr_t len;
char *msg;
Event event;
event = sim_epoll_peek(0);
if (event.tcp != conn->tcp || event.tag != Message || event.any_message != Bytes)
return PG_ASYNC_READ_TRY_AGAIN;
event = sim_epoll_rcv(0);
// walprop_log(LOG, "walprop_async_read, T: %d, tcp: %d, tag: %d", (int) event.tag, (int) event.tcp, (int) event.any_message);
Assert(event.tcp == conn->tcp);
Assert(event.tag == Message);
Assert(event.any_message == Bytes);
msg = (char*) sim_msg_get_bytes(&len);
*buf = msg;
*amount = len;
// walprop_log(LOG, "walprop_async_read: %d", (int) len);
return PG_ASYNC_READ_SUCCESS;
}
PGAsyncWriteResult
walprop_async_write(WalProposerConn *conn, void const *buf, size_t size)
{
// walprop_log(LOG, "walprop_async_write");
sim_msg_set_bytes(buf, size);
sim_tcp_send(conn->tcp);
return PG_ASYNC_WRITE_SUCCESS;
}
/*
* This function is very similar to walprop_async_write. For more
* information, refer to the comments there.
*/
bool
walprop_blocking_write(WalProposerConn *conn, void const *buf, size_t size)
{
// walprop_log(LOG, "walprop_blocking_write");
sim_msg_set_bytes(buf, size);
sim_tcp_send(conn->tcp);
return true;
}
void
sim_start_replication(XLogRecPtr startptr)
{
walprop_log(LOG, "sim_start_replication: %X/%X", LSN_FORMAT_ARGS(startptr));
sim_latest_available_lsn = startptr;
for (;;)
{
XLogRecPtr endptr = sim_latest_available_lsn;
Assert(startptr <= endptr);
if (endptr > startptr)
{
WalProposerBroadcast(startptr, endptr);
startptr = endptr;
}
WalProposerPoll();
}
}
#define UsableBytesInPage (XLOG_BLCKSZ - SizeOfXLogShortPHD)
static int UsableBytesInSegment =
(DEFAULT_XLOG_SEG_SIZE / XLOG_BLCKSZ * UsableBytesInPage) -
(SizeOfXLogLongPHD - SizeOfXLogShortPHD);
/*
* Converts a "usable byte position" to XLogRecPtr. A usable byte position
* is the position starting from the beginning of WAL, excluding all WAL
* page headers.
*/
static XLogRecPtr
XLogBytePosToRecPtr(uint64 bytepos)
{
uint64 fullsegs;
uint64 fullpages;
uint64 bytesleft;
uint32 seg_offset;
XLogRecPtr result;
fullsegs = bytepos / UsableBytesInSegment;
bytesleft = bytepos % UsableBytesInSegment;
if (bytesleft < XLOG_BLCKSZ - SizeOfXLogLongPHD)
{
/* fits on first page of segment */
seg_offset = bytesleft + SizeOfXLogLongPHD;
}
else
{
/* account for the first page on segment with long header */
seg_offset = XLOG_BLCKSZ;
bytesleft -= XLOG_BLCKSZ - SizeOfXLogLongPHD;
fullpages = bytesleft / UsableBytesInPage;
bytesleft = bytesleft % UsableBytesInPage;
seg_offset += fullpages * XLOG_BLCKSZ + bytesleft + SizeOfXLogShortPHD;
}
XLogSegNoOffsetToRecPtr(fullsegs, seg_offset, wal_segment_size, result);
return result;
}
/*
* Convert an XLogRecPtr to a "usable byte position".
*/
static uint64
XLogRecPtrToBytePos(XLogRecPtr ptr)
{
uint64 fullsegs;
uint32 fullpages;
uint32 offset;
uint64 result;
XLByteToSeg(ptr, fullsegs, wal_segment_size);
fullpages = (XLogSegmentOffset(ptr, wal_segment_size)) / XLOG_BLCKSZ;
offset = ptr % XLOG_BLCKSZ;
if (fullpages == 0)
{
result = fullsegs * UsableBytesInSegment;
if (offset > 0)
{
Assert(offset >= SizeOfXLogLongPHD);
result += offset - SizeOfXLogLongPHD;
}
}
else
{
result = fullsegs * UsableBytesInSegment +
(XLOG_BLCKSZ - SizeOfXLogLongPHD) + /* account for first page */
(fullpages - 1) * UsableBytesInPage; /* full pages */
if (offset > 0)
{
Assert(offset >= SizeOfXLogShortPHD);
result += offset - SizeOfXLogShortPHD;
}
}
return result;
}
#define max_rdatas 16
void InitMyInsert();
static void MyBeginInsert();
static void MyRegisterData(char *data, int len);
static XLogRecPtr MyFinishInsert(RmgrId rmid, uint8 info, uint8 flags);
static void MyCopyXLogRecordToWAL(int write_len, XLogRecData *rdata, XLogRecPtr StartPos, XLogRecPtr EndPos);
/*
* An array of XLogRecData structs, to hold registered data.
*/
static XLogRecData rdatas[max_rdatas];
static int num_rdatas; /* entries currently used */
static uint32 mainrdata_len; /* total # of bytes in chain */
static XLogRecData hdr_rdt;
static char hdr_scratch[16000];
static XLogRecPtr CurrBytePos;
static XLogRecPtr PrevBytePos;
void InitMyInsert()
{
CurrBytePos = sim_redo_start_lsn;
PrevBytePos = InvalidXLogRecPtr;
sim_latest_available_lsn = sim_redo_start_lsn;
}
static void MyBeginInsert()
{
num_rdatas = 0;
mainrdata_len = 0;
}
static void MyRegisterData(char *data, int len)
{
XLogRecData *rdata;
if (num_rdatas >= max_rdatas)
walprop_log(ERROR, "too much WAL data");
rdata = &rdatas[num_rdatas++];
rdata->data = data;
rdata->len = len;
rdata->next = NULL;
if (num_rdatas > 1) {
rdatas[num_rdatas - 2].next = rdata;
}
mainrdata_len += len;
}
static XLogRecPtr
MyFinishInsert(RmgrId rmid, uint8 info, uint8 flags)
{
XLogRecData *rdt;
uint32 total_len = 0;
int block_id;
pg_crc32c rdata_crc;
XLogRecord *rechdr;
char *scratch = hdr_scratch;
int size;
XLogRecPtr StartPos;
XLogRecPtr EndPos;
uint64 startbytepos;
uint64 endbytepos;
/*
* Note: this function can be called multiple times for the same record.
* All the modifications we do to the rdata chains below must handle that.
*/
/* The record begins with the fixed-size header */
rechdr = (XLogRecord *) scratch;
scratch += SizeOfXLogRecord;
hdr_rdt.data = hdr_scratch;
if (num_rdatas > 0)
{
hdr_rdt.next = &rdatas[0];
}
else
{
hdr_rdt.next = NULL;
}
/* followed by main data, if any */
if (mainrdata_len > 0)
{
if (mainrdata_len > 255)
{
*(scratch++) = (char) XLR_BLOCK_ID_DATA_LONG;
memcpy(scratch, &mainrdata_len, sizeof(uint32));
scratch += sizeof(uint32);
}
else
{
*(scratch++) = (char) XLR_BLOCK_ID_DATA_SHORT;
*(scratch++) = (uint8) mainrdata_len;
}
total_len += mainrdata_len;
}
hdr_rdt.len = (scratch - hdr_scratch);
total_len += hdr_rdt.len;
/*
* Calculate CRC of the data
*
* Note that the record header isn't added into the CRC initially since we
* don't know the prev-link yet. Thus, the CRC will represent the CRC of
* the whole record in the order: rdata, then backup blocks, then record
* header.
*/
INIT_CRC32C(rdata_crc);
COMP_CRC32C(rdata_crc, hdr_scratch + SizeOfXLogRecord, hdr_rdt.len - SizeOfXLogRecord);
for (size_t i = 0; i < num_rdatas; i++)
{
rdt = &rdatas[i];
COMP_CRC32C(rdata_crc, rdt->data, rdt->len);
}
/*
* Fill in the fields in the record header. Prev-link is filled in later,
* once we know where in the WAL the record will be inserted. The CRC does
* not include the record header yet.
*/
rechdr->xl_xid = 0;
rechdr->xl_tot_len = total_len;
rechdr->xl_info = info;
rechdr->xl_rmid = rmid;
rechdr->xl_prev = InvalidXLogRecPtr;
rechdr->xl_crc = rdata_crc;
size = MAXALIGN(rechdr->xl_tot_len);
/* All (non xlog-switch) records should contain data. */
Assert(size > SizeOfXLogRecord);
startbytepos = XLogRecPtrToBytePos(CurrBytePos);
endbytepos = startbytepos + size;
// Get the position.
StartPos = XLogBytePosToRecPtr(startbytepos);
EndPos = XLogBytePosToRecPtr(startbytepos + size);
rechdr->xl_prev = PrevBytePos;
Assert(XLogRecPtrToBytePos(StartPos) == startbytepos);
Assert(XLogRecPtrToBytePos(EndPos) == endbytepos);
// Update global pointers.
CurrBytePos = EndPos;
PrevBytePos = StartPos;
/*
* Now that xl_prev has been filled in, calculate CRC of the record
* header.
*/
rdata_crc = rechdr->xl_crc;
COMP_CRC32C(rdata_crc, rechdr, offsetof(XLogRecord, xl_crc));
FIN_CRC32C(rdata_crc);
rechdr->xl_crc = rdata_crc;
// Now write it to disk.
MyCopyXLogRecordToWAL(rechdr->xl_tot_len, &hdr_rdt, StartPos, EndPos);
return EndPos;
}
#define INSERT_FREESPACE(endptr) \
(((endptr) % XLOG_BLCKSZ == 0) ? 0 : (XLOG_BLCKSZ - (endptr) % XLOG_BLCKSZ))
static void
MyCopyXLogRecordToWAL(int write_len, XLogRecData *rdata, XLogRecPtr StartPos, XLogRecPtr EndPos)
{
XLogRecPtr CurrPos;
int written;
int freespace;
// Write hdr_rdt and `num_rdatas` other datas.
CurrPos = StartPos;
freespace = INSERT_FREESPACE(CurrPos);
written = 0;
Assert(freespace >= sizeof(uint32));
while (rdata != NULL)
{
char *rdata_data = rdata->data;
int rdata_len = rdata->len;
while (rdata_len >= freespace)
{
char header_buf[SizeOfXLogLongPHD];
XLogPageHeader NewPage = (XLogPageHeader) header_buf;
Assert(CurrPos % XLOG_BLCKSZ >= SizeOfXLogShortPHD || freespace == 0);
XLogWalPropWrite(rdata_data, freespace, CurrPos);
rdata_data += freespace;
rdata_len -= freespace;
written += freespace;
CurrPos += freespace;
// Init new page
MemSet(header_buf, 0, SizeOfXLogLongPHD);
/*
* Fill the new page's header
*/
NewPage->xlp_magic = XLOG_PAGE_MAGIC;
/* NewPage->xlp_info = 0; */ /* done by memset */
NewPage->xlp_tli = 1;
NewPage->xlp_pageaddr = CurrPos;
/* NewPage->xlp_rem_len = 0; */ /* done by memset */
NewPage->xlp_info |= XLP_BKP_REMOVABLE;
/*
* If first page of an XLOG segment file, make it a long header.
*/
if ((XLogSegmentOffset(NewPage->xlp_pageaddr, wal_segment_size)) == 0)
{
XLogLongPageHeader NewLongPage = (XLogLongPageHeader) NewPage;
NewLongPage->xlp_sysid = 0;
NewLongPage->xlp_seg_size = wal_segment_size;
NewLongPage->xlp_xlog_blcksz = XLOG_BLCKSZ;
NewPage->xlp_info |= XLP_LONG_HEADER;
}
NewPage->xlp_rem_len = write_len - written;
if (NewPage->xlp_rem_len > 0) {
NewPage->xlp_info |= XLP_FIRST_IS_CONTRECORD;
}
/* skip over the page header */
if (XLogSegmentOffset(CurrPos, wal_segment_size) == 0)
{
XLogWalPropWrite(header_buf, SizeOfXLogLongPHD, CurrPos);
CurrPos += SizeOfXLogLongPHD;
}
else
{
XLogWalPropWrite(header_buf, SizeOfXLogShortPHD, CurrPos);
CurrPos += SizeOfXLogShortPHD;
}
freespace = INSERT_FREESPACE(CurrPos);
}
Assert(CurrPos % XLOG_BLCKSZ >= SizeOfXLogShortPHD || rdata_len == 0);
XLogWalPropWrite(rdata_data, rdata_len, CurrPos);
CurrPos += rdata_len;
written += rdata_len;
freespace -= rdata_len;
rdata = rdata->next;
}
Assert(written == write_len);
CurrPos = MAXALIGN64(CurrPos);
Assert(CurrPos == EndPos);
}
XLogRecPtr MyInsertRecord()
{
const char *prefix = "prefix";
const char *message = "message";
size_t size = 7;
bool transactional = false;
xl_logical_message xlrec;
xlrec.dbId = 0;
xlrec.transactional = transactional;
/* trailing zero is critical; see logicalmsg_desc */
xlrec.prefix_size = strlen(prefix) + 1;
xlrec.message_size = size;
MyBeginInsert();
MyRegisterData((char *) &xlrec, SizeOfLogicalMessage);
MyRegisterData(unconstify(char *, prefix), xlrec.prefix_size);
MyRegisterData(unconstify(char *, message), size);
return MyFinishInsert(RM_LOGICALMSG_ID, XLOG_LOGICAL_MESSAGE, XLOG_INCLUDE_ORIGIN);
}

View File

@@ -0,0 +1,106 @@
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
/**
* List of all possible AnyMessage.
*/
enum AnyMessageTag {
None,
InternalConnect,
Just32,
ReplCell,
Bytes,
LSN,
};
typedef uint8_t AnyMessageTag;
/**
* List of all possible NodeEvent.
*/
enum EventTag {
Timeout,
Accept,
Closed,
Message,
Internal,
};
typedef uint8_t EventTag;
/**
* Event returned by epoll_recv.
*/
typedef struct Event {
EventTag tag;
int64_t tcp;
AnyMessageTag any_message;
} Event;
void rust_function(uint32_t a);
/**
* C API for the node os.
*/
void sim_sleep(uint64_t ms);
uint64_t sim_random(uint64_t max);
uint32_t sim_id(void);
int64_t sim_open_tcp(uint32_t dst);
int64_t sim_open_tcp_nopoll(uint32_t dst);
/**
* Send MESSAGE_BUF content to the given tcp.
*/
void sim_tcp_send(int64_t tcp);
/**
* Receive a message from the given tcp. Can be used only with tcp opened with
* `sim_open_tcp_nopoll`.
*/
struct Event sim_tcp_recv(int64_t tcp);
struct Event sim_epoll_rcv(int64_t timeout);
struct Event sim_epoll_peek(int64_t timeout);
int64_t sim_now(void);
void sim_exit(int32_t code, const uint8_t *msg);
void sim_set_result(int32_t code, const uint8_t *msg);
void sim_log_event(const int8_t *msg);
/**
* Get tag of the current message.
*/
AnyMessageTag sim_msg_tag(void);
/**
* Read AnyMessage::Just32 message.
*/
void sim_msg_get_just_u32(uint32_t *val);
/**
* Read AnyMessage::LSN message.
*/
void sim_msg_get_lsn(uint64_t *val);
/**
* Write AnyMessage::ReplCell message.
*/
void sim_msg_set_repl_cell(uint32_t value, uint32_t client_id, uint32_t seqno);
/**
* Write AnyMessage::Bytes message.
*/
void sim_msg_set_bytes(const char *bytes, uintptr_t len);
/**
* Read AnyMessage::Bytes message.
*/
const char *sim_msg_get_bytes(uintptr_t *len);

View File

@@ -0,0 +1,36 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
use safekeeper::simlib::node_os::NodeOs;
use tracing::info;
pub mod bindings {
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
#[no_mangle]
pub extern "C" fn rust_function(a: u32) {
info!("Hello from Rust!");
info!("a: {}", a);
}
pub mod sim;
pub mod sim_proto;
#[cfg(test)]
mod test;
#[cfg(test)]
pub mod simtest;
pub fn c_context() -> Option<Box<dyn Fn(NodeOs) + Send + Sync>> {
Some(Box::new(|os: NodeOs| {
sim::c_attach_node_os(os);
unsafe { bindings::MyContextInit(); }
}))
}
pub fn enable_debug() {
unsafe { bindings::debug_enabled = true; }
}

240
libs/walproposer/src/sim.rs Normal file
View File

@@ -0,0 +1,240 @@
use log::debug;
use safekeeper::simlib::{network::TCP, node_os::NodeOs, world::NodeEvent};
use std::{
cell::RefCell,
collections::HashMap,
ffi::{CStr, CString},
};
use tracing::trace;
use crate::sim_proto::{anymessage_tag, AnyMessageTag, Event, EventTag, MESSAGE_BUF};
thread_local! {
static CURRENT_NODE_OS: RefCell<Option<NodeOs>> = RefCell::new(None);
static TCP_CACHE: RefCell<HashMap<i64, TCP>> = RefCell::new(HashMap::new());
}
/// Get the current node os.
fn os() -> NodeOs {
CURRENT_NODE_OS.with(|cell| cell.borrow().clone().expect("no node os set"))
}
fn tcp_save(tcp: TCP) -> i64 {
TCP_CACHE.with(|cell| {
let mut cache = cell.borrow_mut();
let id = tcp.id();
cache.insert(id, tcp);
id
})
}
fn tcp_load(id: i64) -> TCP {
TCP_CACHE.with(|cell| {
let cache = cell.borrow();
cache.get(&id).expect("unknown TCP id").clone()
})
}
/// Should be called before calling any of the C functions.
pub(crate) fn c_attach_node_os(os: NodeOs) {
CURRENT_NODE_OS.with(|cell| {
*cell.borrow_mut() = Some(os);
});
TCP_CACHE.with(|cell| {
*cell.borrow_mut() = HashMap::new();
});
}
/// C API for the node os.
#[no_mangle]
pub extern "C" fn sim_sleep(ms: u64) {
os().sleep(ms);
}
#[no_mangle]
pub extern "C" fn sim_random(max: u64) -> u64 {
os().random(max)
}
#[no_mangle]
pub extern "C" fn sim_id() -> u32 {
os().id().into()
}
#[no_mangle]
pub extern "C" fn sim_open_tcp(dst: u32) -> i64 {
tcp_save(os().open_tcp(dst.into()))
}
#[no_mangle]
pub extern "C" fn sim_open_tcp_nopoll(dst: u32) -> i64 {
tcp_save(os().open_tcp_nopoll(dst.into()))
}
#[no_mangle]
/// Send MESSAGE_BUF content to the given tcp.
pub extern "C" fn sim_tcp_send(tcp: i64) {
tcp_load(tcp).send(MESSAGE_BUF.with(|cell| cell.borrow().clone()));
}
#[no_mangle]
/// Receive a message from the given tcp. Can be used only with tcp opened with
/// `sim_open_tcp_nopoll`.
pub extern "C" fn sim_tcp_recv(tcp: i64) -> Event {
let event = tcp_load(tcp).recv();
match event {
NodeEvent::Accept(_) => unreachable!(),
NodeEvent::Closed(_) => Event {
tag: EventTag::Closed,
tcp: 0,
any_message: AnyMessageTag::None,
},
NodeEvent::Internal(_) => unreachable!(),
NodeEvent::Message((message, _)) => {
// store message in thread local storage, C code should use
// sim_msg_* functions to access it.
MESSAGE_BUF.with(|cell| {
*cell.borrow_mut() = message.clone();
});
Event {
tag: EventTag::Message,
tcp: 0,
any_message: anymessage_tag(&message),
}
}
NodeEvent::WakeTimeout(_) => unreachable!(),
}
}
#[no_mangle]
pub extern "C" fn sim_epoll_rcv(timeout: i64) -> Event {
let event = os().epoll_recv(timeout);
let event = if let Some(event) = event {
event
} else {
return Event {
tag: EventTag::Timeout,
tcp: 0,
any_message: AnyMessageTag::None,
};
};
match event {
NodeEvent::Accept(tcp) => Event {
tag: EventTag::Accept,
tcp: tcp_save(tcp),
any_message: AnyMessageTag::None,
},
NodeEvent::Closed(tcp) => Event {
tag: EventTag::Closed,
tcp: tcp_save(tcp),
any_message: AnyMessageTag::None,
},
NodeEvent::Message((message, tcp)) => {
// store message in thread local storage, C code should use
// sim_msg_* functions to access it.
MESSAGE_BUF.with(|cell| {
*cell.borrow_mut() = message.clone();
});
Event {
tag: EventTag::Message,
tcp: tcp_save(tcp),
any_message: anymessage_tag(&message),
}
}
NodeEvent::Internal(message) => {
// store message in thread local storage, C code should use
// sim_msg_* functions to access it.
MESSAGE_BUF.with(|cell| {
*cell.borrow_mut() = message.clone();
});
Event {
tag: EventTag::Internal,
tcp: 0,
any_message: anymessage_tag(&message),
}
}
NodeEvent::WakeTimeout(_) => {
// can't happen
unreachable!()
}
}
}
#[no_mangle]
pub extern "C" fn sim_epoll_peek(timeout: i64) -> Event {
let event = os().epoll_peek(timeout);
let event = if let Some(event) = event {
event
} else {
return Event {
tag: EventTag::Timeout,
tcp: 0,
any_message: AnyMessageTag::None,
};
};
match event {
NodeEvent::Accept(tcp) => Event {
tag: EventTag::Accept,
tcp: tcp_save(tcp),
any_message: AnyMessageTag::None,
},
NodeEvent::Closed(tcp) => Event {
tag: EventTag::Closed,
tcp: tcp_save(tcp),
any_message: AnyMessageTag::None,
},
NodeEvent::Message((message, tcp)) => Event {
tag: EventTag::Message,
tcp: tcp_save(tcp),
any_message: anymessage_tag(&message),
},
NodeEvent::Internal(message) => Event {
tag: EventTag::Internal,
tcp: 0,
any_message: anymessage_tag(&message),
},
NodeEvent::WakeTimeout(_) => {
// can't happen
unreachable!()
}
}
}
#[no_mangle]
pub extern "C" fn sim_now() -> i64 {
os().now() as i64
}
#[no_mangle]
pub extern "C" fn sim_exit(code: i32, msg: *const u8) {
trace!("sim_exit({}, {:?})", code, msg);
sim_set_result(code, msg);
// I tried to make use of pthread_exit, but it doesn't work.
// https://github.com/rust-lang/unsafe-code-guidelines/issues/211
// unsafe { libc::pthread_exit(std::ptr::null_mut()) };
// https://doc.rust-lang.org/nomicon/unwinding.html
// Everyone on the internet saying this is UB, but it works for me,
// so I'm going to use it for now.
panic!("sim_exit() called from C code")
}
#[no_mangle]
pub extern "C" fn sim_set_result(code: i32, msg: *const u8) {
let msg = unsafe { CStr::from_ptr(msg as *const i8) };
let msg = msg.to_string_lossy().into_owned();
debug!("sim_set_result({}, {:?})", code, msg);
os().set_result(code, msg);
}
#[no_mangle]
pub extern "C" fn sim_log_event(msg: *const i8) {
let msg = unsafe { CStr::from_ptr(msg) };
let msg = msg.to_string_lossy().into_owned();
debug!("sim_log_event({:?})", msg);
os().log_event(msg);
}

View File

@@ -0,0 +1,114 @@
use safekeeper::simlib::proto::{AnyMessage, ReplCell};
use std::{cell::RefCell, ffi::c_char};
pub(crate) fn anymessage_tag(msg: &AnyMessage) -> AnyMessageTag {
match msg {
AnyMessage::None => AnyMessageTag::None,
AnyMessage::InternalConnect => AnyMessageTag::InternalConnect,
AnyMessage::Just32(_) => AnyMessageTag::Just32,
AnyMessage::ReplCell(_) => AnyMessageTag::ReplCell,
AnyMessage::Bytes(_) => AnyMessageTag::Bytes,
AnyMessage::LSN(_) => AnyMessageTag::LSN,
}
}
thread_local! {
pub static MESSAGE_BUF: RefCell<AnyMessage> = RefCell::new(AnyMessage::None);
}
#[no_mangle]
/// Get tag of the current message.
pub extern "C" fn sim_msg_tag() -> AnyMessageTag {
MESSAGE_BUF.with(|cell| anymessage_tag(&*cell.borrow()))
}
#[no_mangle]
/// Read AnyMessage::Just32 message.
pub extern "C" fn sim_msg_get_just_u32(val: &mut u32) {
MESSAGE_BUF.with(|cell| match &*cell.borrow() {
AnyMessage::Just32(v) => {
*val = *v;
}
_ => panic!("expected Just32 message"),
});
}
#[no_mangle]
/// Read AnyMessage::LSN message.
pub extern "C" fn sim_msg_get_lsn(val: &mut u64) {
MESSAGE_BUF.with(|cell| match &*cell.borrow() {
AnyMessage::LSN(v) => {
*val = *v;
}
_ => panic!("expected LSN message"),
});
}
#[no_mangle]
/// Write AnyMessage::ReplCell message.
pub extern "C" fn sim_msg_set_repl_cell(value: u32, client_id: u32, seqno: u32) {
MESSAGE_BUF.with(|cell| {
*cell.borrow_mut() = AnyMessage::ReplCell(ReplCell {
value,
client_id,
seqno,
});
});
}
#[no_mangle]
/// Write AnyMessage::Bytes message.
pub extern "C" fn sim_msg_set_bytes(bytes: *const c_char, len: usize) {
MESSAGE_BUF.with(|cell| {
// copy bytes to a Rust Vec
let mut v: Vec<u8> = Vec::with_capacity(len);
unsafe {
v.set_len(len);
std::ptr::copy_nonoverlapping(bytes as *const u8, v.as_mut_ptr(), len);
}
*cell.borrow_mut() = AnyMessage::Bytes(v.into());
});
}
#[no_mangle]
/// Read AnyMessage::Bytes message.
pub extern "C" fn sim_msg_get_bytes(len: *mut usize) -> *const c_char {
MESSAGE_BUF.with(|cell| match &*cell.borrow() {
AnyMessage::Bytes(v) => {
unsafe {
*len = v.len();
v.as_ptr() as *const i8
}
}
_ => panic!("expected Bytes message"),
})
}
#[repr(C)]
/// Event returned by epoll_recv.
pub struct Event {
pub tag: EventTag,
pub tcp: i64,
pub any_message: AnyMessageTag,
}
#[repr(u8)]
/// List of all possible NodeEvent.
pub enum EventTag {
Timeout,
Accept,
Closed,
Message,
Internal,
}
#[repr(u8)]
/// List of all possible AnyMessage.
pub enum AnyMessageTag {
None,
InternalConnect,
Just32,
ReplCell,
Bytes,
LSN,
}

View File

@@ -0,0 +1,88 @@
use std::collections::HashMap;
use std::sync::Arc;
use safekeeper::safekeeper::SafeKeeperState;
use safekeeper::simlib::sync::Mutex;
use utils::id::TenantTimelineId;
pub struct Disk {
pub timelines: Mutex<HashMap<TenantTimelineId, Arc<TimelineDisk>>>,
}
impl Disk {
pub fn new() -> Self {
Disk {
timelines: Mutex::new(HashMap::new()),
}
}
pub fn put_state(&self, ttid: &TenantTimelineId, state: SafeKeeperState) -> Arc<TimelineDisk> {
self.timelines
.lock()
.entry(ttid.clone())
.and_modify(|e| {
let mut mu = e.state.lock();
*mu = state.clone();
})
.or_insert_with(|| {
Arc::new(TimelineDisk {
state: Mutex::new(state),
wal: Mutex::new(BlockStorage::new()),
})
})
.clone()
}
}
pub struct TimelineDisk {
pub state: Mutex<SafeKeeperState>,
pub wal: Mutex<BlockStorage>,
}
const BLOCK_SIZE: usize = 8192;
pub struct BlockStorage {
blocks: HashMap<u64, [u8; BLOCK_SIZE]>,
}
impl BlockStorage {
pub fn new() -> Self {
BlockStorage {
blocks: HashMap::new(),
}
}
pub fn read(&self, pos: u64, buf: &mut [u8]) {
let mut buf_offset = 0;
let mut storage_pos = pos;
while buf_offset < buf.len() {
let block_id = storage_pos / BLOCK_SIZE as u64;
let block = self.blocks.get(&block_id).unwrap_or(&[0; BLOCK_SIZE]);
let block_offset = storage_pos % BLOCK_SIZE as u64;
let block_len = BLOCK_SIZE as u64 - block_offset;
let buf_len = buf.len() - buf_offset;
let copy_len = std::cmp::min(block_len as usize, buf_len);
buf[buf_offset..buf_offset + copy_len]
.copy_from_slice(&block[block_offset as usize..block_offset as usize + copy_len]);
buf_offset += copy_len;
storage_pos += copy_len as u64;
}
}
pub fn write(&mut self, pos: u64, buf: &[u8]) {
let mut buf_offset = 0;
let mut storage_pos = pos;
while buf_offset < buf.len() {
let block_id = storage_pos / BLOCK_SIZE as u64;
let block = self.blocks.entry(block_id).or_insert([0; BLOCK_SIZE]);
let block_offset = storage_pos % BLOCK_SIZE as u64;
let block_len = BLOCK_SIZE as u64 - block_offset;
let buf_len = buf.len() - buf_offset;
let copy_len = std::cmp::min(block_len as usize, buf_len);
block[block_offset as usize..block_offset as usize + copy_len]
.copy_from_slice(&buf[buf_offset..buf_offset + copy_len]);
buf_offset += copy_len;
storage_pos += copy_len as u64
}
}
}

View File

@@ -0,0 +1,61 @@
use std::{sync::Arc, fmt};
use safekeeper::simlib::{world::World, sync::Mutex};
use tracing_subscriber::fmt::{time::FormatTime, format::Writer};
use utils::logging;
use crate::bindings;
#[derive(Clone)]
pub struct SimClock {
world_ptr: Arc<Mutex<Option<Arc<World>>>>,
}
impl Default for SimClock {
fn default() -> Self {
SimClock {
world_ptr: Arc::new(Mutex::new(None)),
}
}
}
impl SimClock {
pub fn set_world(&self, world: Arc<World>) {
*self.world_ptr.lock() = Some(world);
}
}
impl FormatTime for SimClock {
fn format_time(&self, w: &mut Writer<'_>) -> fmt::Result {
let world = self.world_ptr.lock().clone();
if let Some(world) = world {
let now = world.now();
write!(w, "[{}]", now)
} else {
write!(w, "[?]")
}
}
}
pub fn init_logger() -> SimClock {
let debug_enabled = unsafe { bindings::debug_enabled };
let clock = SimClock::default();
let base_logger = tracing_subscriber::fmt()
.with_target(false)
.with_timer(clock.clone())
.with_ansi(true)
.with_max_level(match debug_enabled {
true => tracing::Level::DEBUG,
false => tracing::Level::INFO,
})
.with_writer(std::io::stdout);
base_logger.init();
// logging::replace_panic_hook_with_tracing_panic_hook().forget();
std::panic::set_hook(Box::new(|_| {}));
clock
}

View File

@@ -0,0 +1,11 @@
#[cfg(test)]
pub mod simple_client;
#[cfg(test)]
pub mod wp_sk;
pub mod disk;
pub mod safekeeper;
pub mod storage;
pub mod log;
pub mod util;

View File

@@ -0,0 +1,372 @@
//! Safekeeper communication endpoint to WAL proposer (compute node).
//! Gets messages from the network, passes them down to consensus module and
//! sends replies back.
use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
use anyhow::{anyhow, bail, Result};
use bytes::{Bytes, BytesMut};
use hyper::Uri;
use log::info;
use safekeeper::{
safekeeper::{
ProposerAcceptorMessage, SafeKeeper, SafeKeeperState, ServerInfo, UNKNOWN_SERVER_VERSION,
},
simlib::{network::TCP, node_os::NodeOs, proto::AnyMessage, world::NodeEvent},
timeline::TimelineError,
SafeKeeperConf, wal_storage::Storage,
};
use tracing::{debug, info_span};
use utils::{
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
lsn::Lsn,
};
use crate::simtest::storage::DiskStateStorage;
use super::{
disk::{Disk, TimelineDisk},
storage::DiskWALStorage,
};
struct ConnState {
tcp: TCP,
greeting: bool,
ttid: TenantTimelineId,
flush_pending: bool,
}
struct SharedState {
sk: SafeKeeper<DiskStateStorage, DiskWALStorage>,
disk: Arc<TimelineDisk>,
}
struct GlobalMap {
timelines: HashMap<TenantTimelineId, SharedState>,
conf: SafeKeeperConf,
disk: Arc<Disk>,
}
impl GlobalMap {
fn new(disk: Arc<Disk>, conf: SafeKeeperConf) -> Result<Self> {
let mut timelines = HashMap::new();
for (&ttid, disk) in disk.timelines.lock().iter() {
debug!("loading timeline {}", ttid);
let state = disk.state.lock().clone();
if state.server.wal_seg_size == 0 {
bail!(TimelineError::UninitializedWalSegSize(ttid));
}
if state.server.pg_version == UNKNOWN_SERVER_VERSION {
bail!(TimelineError::UninitialinzedPgVersion(ttid));
}
if state.commit_lsn < state.local_start_lsn {
bail!(
"commit_lsn {} is higher than local_start_lsn {}",
state.commit_lsn,
state.local_start_lsn
);
}
let control_store = DiskStateStorage::new(disk.clone());
let wal_store = DiskWALStorage::new(disk.clone(), &control_store)?;
let sk = SafeKeeper::new(control_store, wal_store, conf.my_id)?;
timelines.insert(
ttid.clone(),
SharedState {
sk,
disk: disk.clone(),
},
);
}
Ok(Self {
timelines,
conf,
disk,
})
}
fn create(&mut self, ttid: TenantTimelineId, server_info: ServerInfo) -> Result<()> {
if self.timelines.contains_key(&ttid) {
bail!("timeline {} already exists", ttid);
}
debug!("creating new timeline {}", ttid);
let commit_lsn = Lsn::INVALID;
let local_start_lsn = Lsn::INVALID;
// TODO: load state from in-memory storage
let state = SafeKeeperState::new(&ttid, server_info, vec![], commit_lsn, local_start_lsn);
if state.server.wal_seg_size == 0 {
bail!(TimelineError::UninitializedWalSegSize(ttid));
}
if state.server.pg_version == UNKNOWN_SERVER_VERSION {
bail!(TimelineError::UninitialinzedPgVersion(ttid));
}
if state.commit_lsn < state.local_start_lsn {
bail!(
"commit_lsn {} is higher than local_start_lsn {}",
state.commit_lsn,
state.local_start_lsn
);
}
let disk_timeline = self.disk.put_state(&ttid, state);
let control_store = DiskStateStorage::new(disk_timeline.clone());
let wal_store = DiskWALStorage::new(disk_timeline.clone(), &control_store)?;
let sk = SafeKeeper::new(control_store, wal_store, self.conf.my_id)?;
self.timelines.insert(
ttid.clone(),
SharedState {
sk,
disk: disk_timeline,
},
);
Ok(())
}
fn get(&mut self, ttid: &TenantTimelineId) -> &mut SharedState {
self.timelines.get_mut(ttid).expect("timeline must exist")
}
fn has_tli(&self, ttid: &TenantTimelineId) -> bool {
self.timelines.contains_key(ttid)
}
}
pub fn run_server(os: NodeOs, disk: Arc<Disk>) -> Result<()> {
let _enter = info_span!("safekeeper", id = os.id()).entered();
debug!("started server");
os.log_event("started;safekeeper".to_owned());
let conf = SafeKeeperConf {
workdir: PathBuf::from("."),
my_id: NodeId(os.id() as u64),
listen_pg_addr: String::new(),
listen_http_addr: String::new(),
no_sync: false,
broker_endpoint: "/".parse::<Uri>().unwrap(),
broker_keepalive_interval: Duration::from_secs(0),
heartbeat_timeout: Duration::from_secs(0),
remote_storage: None,
max_offloader_lag_bytes: 0,
backup_runtime_threads: None,
wal_backup_enabled: false,
auth: None,
};
let mut global = GlobalMap::new(disk, conf.clone())?;
let mut conns: HashMap<i64, ConnState> = HashMap::new();
for (&ttid, shared_state) in global.timelines.iter_mut() {
let flush_lsn = shared_state.sk.wal_store.flush_lsn();
let commit_lsn = shared_state.sk.state.commit_lsn;
os.log_event(format!("tli_loaded;{};{}", flush_lsn.0, commit_lsn.0));
}
let epoll = os.epoll();
loop {
// waiting for the next message
let mut next_event = Some(epoll.recv());
loop {
let event = match next_event {
Some(event) => event,
None => break,
};
match event {
NodeEvent::Accept(tcp) => {
conns.insert(
tcp.id(),
ConnState {
tcp,
greeting: false,
ttid: TenantTimelineId::empty(),
flush_pending: false,
},
);
}
NodeEvent::Message((msg, tcp)) => {
let conn = conns.get_mut(&tcp.id());
if let Some(conn) = conn {
let res = conn.process_any(msg, &mut global);
if res.is_err() {
debug!("conn {:?} error: {:#}", tcp, res.unwrap_err());
conns.remove(&tcp.id());
}
} else {
debug!("conn {:?} was closed, dropping msg {:?}", tcp, msg);
}
}
NodeEvent::Internal(_) => {}
NodeEvent::Closed(_) => {}
NodeEvent::WakeTimeout(_) => {}
}
// TODO: make simulator support multiple events per tick
next_event = epoll.try_recv();
}
conns.retain(|_, conn| {
let res = conn.flush(&mut global);
if res.is_err() {
debug!("conn {:?} error: {:?}", conn.tcp, res);
}
res.is_ok()
});
}
}
impl ConnState {
fn process_any(&mut self, any: AnyMessage, global: &mut GlobalMap) -> Result<()> {
if let AnyMessage::Bytes(copy_data) = any {
let repl_prefix = b"START_REPLICATION ";
if !self.greeting && copy_data.starts_with(repl_prefix) {
self.process_start_replication(copy_data.slice(repl_prefix.len()..), global)?;
bail!("finished processing START_REPLICATION")
}
let msg = ProposerAcceptorMessage::parse(copy_data)?;
debug!("got msg: {:?}", msg);
return self.process(msg, global);
} else {
bail!("unexpected message, expected AnyMessage::Bytes");
}
}
fn process_start_replication(
&mut self,
copy_data: Bytes,
global: &mut GlobalMap,
) -> Result<()> {
// format is "<tenant_id> <timeline_id> <start_lsn> <end_lsn>"
let str = String::from_utf8(copy_data.to_vec())?;
let mut parts = str.split(' ');
let tenant_id = parts.next().unwrap().parse::<TenantId>()?;
let timeline_id = parts.next().unwrap().parse::<TimelineId>()?;
let start_lsn = parts.next().unwrap().parse::<u64>()?;
let end_lsn = parts.next().unwrap().parse::<u64>()?;
let ttid = TenantTimelineId::new(tenant_id, timeline_id);
let shared_state = global.get(&ttid);
// read bytes from start_lsn to end_lsn
let mut buf = vec![0; (end_lsn - start_lsn) as usize];
shared_state.disk.wal.lock().read(start_lsn, &mut buf);
// send bytes to the client
self.tcp.send(AnyMessage::Bytes(Bytes::from(buf)));
Ok(())
}
fn init_timeline(
&mut self,
ttid: TenantTimelineId,
server_info: ServerInfo,
global: &mut GlobalMap,
) -> Result<()> {
self.ttid = ttid;
if global.has_tli(&ttid) {
return Ok(());
}
global.create(ttid, server_info)
}
fn process(&mut self, msg: ProposerAcceptorMessage, global: &mut GlobalMap) -> Result<()> {
if !self.greeting {
self.greeting = true;
match msg {
ProposerAcceptorMessage::Greeting(ref greeting) => {
debug!(
"start handshake with walproposer {:?}",
self.tcp,
);
let server_info = ServerInfo {
pg_version: greeting.pg_version,
system_id: greeting.system_id,
wal_seg_size: greeting.wal_seg_size,
};
let ttid = TenantTimelineId::new(greeting.tenant_id, greeting.timeline_id);
self.init_timeline(ttid, server_info, global)?
}
_ => {
bail!("unexpected message {msg:?} instead of greeting");
}
}
}
let tli = global.get(&self.ttid);
match msg {
ProposerAcceptorMessage::AppendRequest(append_request) => {
self.flush_pending = true;
self.process_sk_msg(
tli,
&ProposerAcceptorMessage::NoFlushAppendRequest(append_request),
)?;
}
other => {
self.process_sk_msg(tli, &other)?;
}
}
Ok(())
}
/// Process FlushWAL if needed.
// TODO: add extra flushes, to verify that extra flushes don't break anything
fn flush(&mut self, global: &mut GlobalMap) -> Result<()> {
if !self.flush_pending {
return Ok(());
}
self.flush_pending = false;
let shared_state = global.get(&self.ttid);
self.process_sk_msg(shared_state, &ProposerAcceptorMessage::FlushWAL)
}
/// Make safekeeper process a message and send a reply to the TCP
fn process_sk_msg(
&mut self,
shared_state: &mut SharedState,
msg: &ProposerAcceptorMessage,
) -> Result<()> {
let mut reply = shared_state.sk.process_msg(msg)?;
if let Some(reply) = &mut reply {
// // if this is AppendResponse, fill in proper hot standby feedback and disk consistent lsn
// if let AcceptorProposerMessage::AppendResponse(ref mut resp) = reply {
// // TODO:
// }
let mut buf = BytesMut::with_capacity(128);
reply.serialize(&mut buf)?;
self.tcp.send(AnyMessage::Bytes(buf.into()));
}
Ok(())
}
}
impl Drop for ConnState {
fn drop(&mut self) {
debug!("dropping conn: {:?}", self.tcp);
if !std::thread::panicking() {
self.tcp.close();
}
// TODO: clean up non-fsynced WAL
}
}

View File

@@ -0,0 +1,38 @@
use std::sync::Arc;
use safekeeper::{
simlib::{
network::{Delay, NetworkOptions},
world::World,
},
simtest::{start_simulation, Options},
};
use crate::{bindings::RunClientC, c_context};
#[test]
fn run_rust_c_test() {
let delay = Delay {
min: 1,
max: 5,
fail_prob: 0.5,
};
let network = NetworkOptions {
keepalive_timeout: Some(50),
connect_delay: delay.clone(),
send_delay: delay.clone(),
};
let u32_data: [u32; 5] = [1, 2, 3, 4, 5];
let world = Arc::new(World::new(1337, Arc::new(network), c_context()));
start_simulation(Options {
world,
time_limit: 1_000_000,
client_fn: Box::new(move |_, server_id| unsafe {
RunClientC(server_id);
}),
u32_data,
});
}

View File

@@ -0,0 +1,234 @@
use std::{ops::Deref, sync::Arc};
use anyhow::Result;
use bytes::{Buf, BytesMut};
use log::{debug, info};
use postgres_ffi::{waldecoder::WalStreamDecoder, XLogSegNo};
use safekeeper::{control_file, safekeeper::SafeKeeperState, wal_storage};
use utils::lsn::Lsn;
use super::disk::TimelineDisk;
pub struct DiskStateStorage {
persisted_state: SafeKeeperState,
disk: Arc<TimelineDisk>,
}
impl DiskStateStorage {
pub fn new(disk: Arc<TimelineDisk>) -> Self {
let guard = disk.state.lock();
let state = guard.clone();
drop(guard);
DiskStateStorage {
persisted_state: state,
disk,
}
}
}
impl control_file::Storage for DiskStateStorage {
fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
self.persisted_state = s.clone();
*self.disk.state.lock() = s.clone();
Ok(())
}
}
impl Deref for DiskStateStorage {
type Target = SafeKeeperState;
fn deref(&self) -> &Self::Target {
&self.persisted_state
}
}
pub struct DummyWalStore {
lsn: Lsn,
}
impl DummyWalStore {
pub fn new() -> Self {
DummyWalStore { lsn: Lsn::INVALID }
}
}
impl wal_storage::Storage for DummyWalStore {
fn flush_lsn(&self) -> Lsn {
self.lsn
}
fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
self.lsn = startpos + buf.len() as u64;
Ok(())
}
fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> {
self.lsn = end_pos;
Ok(())
}
fn flush_wal(&mut self) -> Result<()> {
Ok(())
}
fn remove_up_to(&self) -> Box<dyn Fn(XLogSegNo) -> Result<()>> {
Box::new(move |_segno_up_to: XLogSegNo| Ok(()))
}
fn get_metrics(&self) -> safekeeper::metrics::WalStorageMetrics {
safekeeper::metrics::WalStorageMetrics::default()
}
}
pub struct DiskWALStorage {
/// Written to disk, but possibly still in the cache and not fully persisted.
/// Also can be ahead of record_lsn, if happen to be in the middle of a WAL record.
write_lsn: Lsn,
/// The LSN of the last WAL record written to disk. Still can be not fully flushed.
write_record_lsn: Lsn,
/// The LSN of the last WAL record flushed to disk.
flush_record_lsn: Lsn,
/// Decoder is required for detecting boundaries of WAL records.
decoder: WalStreamDecoder,
unflushed_bytes: BytesMut,
disk: Arc<TimelineDisk>,
}
impl DiskWALStorage {
pub fn new(disk: Arc<TimelineDisk>, state: &SafeKeeperState) -> Result<Self> {
let write_lsn = if state.commit_lsn == Lsn(0) {
Lsn(0)
} else {
Self::find_end_of_wal(disk.clone(), state.commit_lsn)?
};
let flush_lsn = write_lsn;
Ok(DiskWALStorage {
write_lsn,
write_record_lsn: flush_lsn,
flush_record_lsn: flush_lsn,
decoder: WalStreamDecoder::new(flush_lsn, 15),
unflushed_bytes: BytesMut::new(),
disk,
})
}
fn find_end_of_wal(disk: Arc<TimelineDisk>, start_lsn: Lsn) -> Result<Lsn> {
let mut buf = [0; 8192];
let mut pos = start_lsn.0;
let mut decoder = WalStreamDecoder::new(start_lsn, 15);
let mut result = start_lsn;
loop {
disk.wal.lock().read(pos, &mut buf);
pos += buf.len() as u64;
decoder.feed_bytes(&buf);
loop {
match decoder.poll_decode() {
Ok(Some(record)) => result = record.0,
Err(e) => {
debug!(
"find_end_of_wal reached end at {:?}, decode error: {:?}",
result, e
);
return Ok(result);
}
Ok(None) => break, // need more data
}
}
}
}
}
impl wal_storage::Storage for DiskWALStorage {
fn flush_lsn(&self) -> Lsn {
self.flush_record_lsn
}
fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
if self.write_lsn != startpos {
panic!("write_wal called with wrong startpos");
}
self.unflushed_bytes.extend_from_slice(buf);
self.write_lsn += buf.len() as u64;
if self.decoder.available() != startpos {
info!(
"restart decoder from {} to {}",
self.decoder.available(),
startpos,
);
self.decoder = WalStreamDecoder::new(startpos, 15);
}
self.decoder.feed_bytes(buf);
loop {
match self.decoder.poll_decode()? {
None => break, // no full record yet
Some((lsn, _rec)) => {
self.write_record_lsn = lsn;
}
}
}
Ok(())
}
fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> {
if self.write_lsn != Lsn(0) && end_pos > self.write_lsn {
panic!(
"truncate_wal called on non-written WAL, write_lsn={}, end_pos={}",
self.write_lsn, end_pos
);
}
self.flush_wal()?;
// write zeroes to disk from end_pos until self.write_lsn
let buf = [0; 8192];
let mut pos = end_pos.0;
while pos < self.write_lsn.0 {
self.disk.wal.lock().write(pos, &buf);
pos += buf.len() as u64;
}
self.write_lsn = end_pos;
self.write_record_lsn = end_pos;
self.flush_record_lsn = end_pos;
self.unflushed_bytes.clear();
self.decoder = WalStreamDecoder::new(end_pos, 15);
Ok(())
}
fn flush_wal(&mut self) -> Result<()> {
if self.flush_record_lsn == self.write_record_lsn {
// no need to do extra flush
return Ok(());
}
let num_bytes = self.write_record_lsn.0 - self.flush_record_lsn.0;
self.disk.wal.lock().write(
self.flush_record_lsn.0,
&self.unflushed_bytes[..num_bytes as usize],
);
self.unflushed_bytes.advance(num_bytes as usize);
self.flush_record_lsn = self.write_record_lsn;
Ok(())
}
fn remove_up_to(&self) -> Box<dyn Fn(XLogSegNo) -> Result<()>> {
Box::new(move |_segno_up_to: XLogSegNo| Ok(()))
}
fn get_metrics(&self) -> safekeeper::metrics::WalStorageMetrics {
safekeeper::metrics::WalStorageMetrics::default()
}
}

View File

@@ -0,0 +1,610 @@
use std::{ffi::CString, path::Path, str::FromStr, sync::Arc, collections::HashMap};
use rand::{Rng, SeedableRng};
use safekeeper::simlib::{
network::{Delay, NetworkOptions},
proto::AnyMessage,
time::EmptyEvent,
world::World,
world::{Node, NodeEvent, SEvent, NodeId},
};
use tracing::{debug, error, info, warn};
use utils::{id::TenantTimelineId, lsn::Lsn};
use crate::{
bindings::{
neon_tenant_walproposer, neon_timeline_walproposer, sim_redo_start_lsn, syncSafekeepers,
wal_acceptor_connection_timeout, wal_acceptor_reconnect_timeout, wal_acceptors_list,
MyInsertRecord, WalProposerCleanup, WalProposerRust,
},
c_context,
simtest::{
log::{init_logger, SimClock},
safekeeper::run_server,
},
};
use super::disk::Disk;
pub struct SkNode {
pub node: Arc<Node>,
pub id: u32,
pub disk: Arc<Disk>,
}
impl SkNode {
pub fn new(node: Arc<Node>) -> Self {
let disk = Arc::new(Disk::new());
let res = Self {
id: node.id,
node,
disk,
};
res.launch();
res
}
pub fn launch(&self) {
let id = self.id;
let disk = self.disk.clone();
// start the server thread
self.node.launch(move |os| {
let res = run_server(os, disk);
debug!("server {} finished: {:?}", id, res);
});
}
pub fn restart(&self) {
self.node.crash_stop();
self.launch();
}
}
pub struct TestConfig {
pub network: NetworkOptions,
pub timeout: u64,
pub clock: Option<SimClock>,
}
impl TestConfig {
pub fn new(clock: Option<SimClock>) -> Self {
Self {
network: NetworkOptions {
keepalive_timeout: Some(2000),
connect_delay: Delay {
min: 1,
max: 5,
fail_prob: 0.0,
},
send_delay: Delay {
min: 1,
max: 5,
fail_prob: 0.0,
},
},
timeout: 1_000 * 10,
clock,
}
}
pub fn start(&self, seed: u64) -> Test {
let world = Arc::new(World::new(
seed,
Arc::new(self.network.clone()),
c_context(),
));
world.register_world();
if let Some(clock) = &self.clock {
clock.set_world(world.clone());
}
let servers = [
SkNode::new(world.new_node()),
SkNode::new(world.new_node()),
SkNode::new(world.new_node()),
];
let server_ids = [servers[0].id, servers[1].id, servers[2].id];
let safekeepers_guc = server_ids.map(|id| format!("node:{}", id)).join(",");
let ttid = TenantTimelineId::generate();
// wait init for all servers
world.await_all();
// clean up pgdata directory
self.init_pgdata();
Test {
world,
servers,
safekeepers_guc,
ttid,
timeout: self.timeout,
}
}
pub fn init_pgdata(&self) {
let pgdata = Path::new("/home/admin/simulator/libs/walproposer/pgdata");
if pgdata.exists() {
std::fs::remove_dir_all(pgdata).unwrap();
}
std::fs::create_dir(pgdata).unwrap();
// create empty pg_wal and pg_notify subdirs
std::fs::create_dir(pgdata.join("pg_wal")).unwrap();
std::fs::create_dir(pgdata.join("pg_notify")).unwrap();
// write postgresql.conf
let mut conf = std::fs::File::create(pgdata.join("postgresql.conf")).unwrap();
let content = "
wal_log_hints=off
hot_standby=on
fsync=off
wal_level=replica
restart_after_crash=off
shared_preload_libraries=neon
neon.pageserver_connstring=''
neon.tenant_id=cc6e67313d57283bad411600fbf5c142
neon.timeline_id=de6fa815c1e45aa61491c3d34c4eb33e
synchronous_standby_names=walproposer
neon.safekeepers='node:1,node:2,node:3'
max_connections=100
";
std::io::Write::write_all(&mut conf, content.as_bytes()).unwrap();
}
}
pub struct Test {
pub world: Arc<World>,
pub servers: [SkNode; 3],
pub safekeepers_guc: String,
pub ttid: TenantTimelineId,
pub timeout: u64,
}
impl Test {
fn launch_sync(&self) -> Arc<Node> {
let client_node = self.world.new_node();
debug!("sync-safekeepers started at node {}", client_node.id);
// start the client thread
let guc = self.safekeepers_guc.clone();
let ttid = self.ttid.clone();
client_node.launch(move |_| {
let list = CString::new(guc).unwrap();
unsafe {
WalProposerCleanup();
syncSafekeepers = true;
wal_acceptors_list = list.into_raw();
wal_acceptor_reconnect_timeout = 1000;
wal_acceptor_connection_timeout = 5000;
neon_tenant_walproposer =
CString::new(ttid.tenant_id.to_string()).unwrap().into_raw();
neon_timeline_walproposer = CString::new(ttid.timeline_id.to_string())
.unwrap()
.into_raw();
WalProposerRust();
}
});
self.world.await_all();
client_node
}
pub fn sync_safekeepers(&self) -> anyhow::Result<Lsn> {
let client_node = self.launch_sync();
// poll until exit or timeout
let time_limit = self.timeout;
while self.world.step() && self.world.now() < time_limit && !client_node.is_finished() {}
if !client_node.is_finished() {
anyhow::bail!("timeout or idle stuck");
}
let res = client_node.result.lock().clone();
if res.0 != 0 {
anyhow::bail!("non-zero exitcode: {:?}", res);
}
let lsn = Lsn::from_str(&res.1)?;
Ok(lsn)
}
pub fn launch_walproposer(&self, lsn: Lsn) -> WalProposer {
let client_node = self.world.new_node();
let lsn = if lsn.0 == 0 {
// usual LSN after basebackup
Lsn(21623024)
} else {
lsn
};
// start the client thread
let guc = self.safekeepers_guc.clone();
let ttid = self.ttid.clone();
client_node.launch(move |_| {
let list = CString::new(guc).unwrap();
unsafe {
WalProposerCleanup();
sim_redo_start_lsn = lsn.0;
syncSafekeepers = false;
wal_acceptors_list = list.into_raw();
wal_acceptor_reconnect_timeout = 1000;
wal_acceptor_connection_timeout = 5000;
neon_tenant_walproposer =
CString::new(ttid.tenant_id.to_string()).unwrap().into_raw();
neon_timeline_walproposer = CString::new(ttid.timeline_id.to_string())
.unwrap()
.into_raw();
WalProposerRust();
}
});
self.world.await_all();
WalProposer {
node: client_node,
}
}
pub fn poll_for_duration(&self, duration: u64) {
let time_limit = std::cmp::min(self.world.now() + duration, self.timeout);
while self.world.step() && self.world.now() < time_limit {}
}
pub fn run_schedule(&self, schedule: &Schedule) -> anyhow::Result<()> {
{
let empty_event = Box::new(EmptyEvent);
let now = self.world.now();
for (time, _) in schedule {
if *time < now {
continue;
}
self.world.schedule(*time - now, empty_event.clone())
}
}
let mut wait_node = self.launch_sync();
// fake walproposer
let mut wp = WalProposer {
node: wait_node.clone(),
};
let mut sync_in_progress = true;
let mut skipped_tx = 0;
let mut started_tx = 0;
let mut schedule_ptr = 0;
loop {
if sync_in_progress && wait_node.is_finished() {
let res = wait_node.result.lock().clone();
if res.0 != 0 {
warn!("sync non-zero exitcode: {:?}", res);
debug!("restarting walproposer");
wait_node = self.launch_sync();
continue;
}
let lsn = Lsn::from_str(&res.1)?;
debug!("sync-safekeepers finished at LSN {}", lsn);
wp = self.launch_walproposer(lsn);
wait_node = wp.node.clone();
debug!("walproposer started at node {}", wait_node.id);
sync_in_progress = false;
}
let now = self.world.now();
while schedule_ptr < schedule.len() && schedule[schedule_ptr].0 <= now {
if now != schedule[schedule_ptr].0 {
warn!("skipped event {:?} at {}", schedule[schedule_ptr], now);
}
let action = &schedule[schedule_ptr].1;
match action {
TestAction::WriteTx(size) => {
if !sync_in_progress && !wait_node.is_finished() {
started_tx += *size;
wp.write_tx(*size);
debug!("written {} transactions", size);
} else {
skipped_tx += size;
debug!("skipped {} transactions", size);
}
}
TestAction::RestartSafekeeper(id) => {
debug!("restarting safekeeper {}", id);
self.servers[*id as usize].restart();
}
TestAction::RestartWalProposer => {
debug!("restarting walproposer");
wait_node.crash_stop();
sync_in_progress = true;
wait_node = self.launch_sync();
}
}
schedule_ptr += 1;
}
if schedule_ptr == schedule.len() {
break;
}
let next_event_time = schedule[schedule_ptr].0;
// poll until the next event
if wait_node.is_finished() {
while self.world.step() && self.world.now() < next_event_time {}
} else {
while self.world.step()
&& self.world.now() < next_event_time
&& !wait_node.is_finished()
{}
}
}
debug!("finished schedule");
debug!("skipped_tx: {}", skipped_tx);
debug!("started_tx: {}", started_tx);
Ok(())
}
}
pub struct WalProposer {
pub node: Arc<Node>,
}
impl WalProposer {
pub fn write_tx(&mut self, cnt: usize) {
self.node
.network_chan()
.send(NodeEvent::Internal(AnyMessage::Just32(cnt as u32)));
}
pub fn stop(&self) {
self.node.crash_stop();
}
}
#[derive(Debug, Clone)]
pub enum TestAction {
WriteTx(usize),
RestartSafekeeper(usize),
RestartWalProposer,
}
pub type Schedule = Vec<(u64, TestAction)>;
pub fn generate_schedule(seed: u64) -> Schedule {
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let mut schedule = Vec::new();
let mut time = 0;
let cnt = rng.gen_range(1..100);
for _ in 0..cnt {
time += rng.gen_range(0..500);
let action = match rng.gen_range(0..3) {
0 => TestAction::WriteTx(rng.gen_range(1..10)),
1 => TestAction::RestartSafekeeper(rng.gen_range(0..3)),
2 => TestAction::RestartWalProposer,
_ => unreachable!(),
};
schedule.push((time, action));
}
schedule
}
pub fn generate_network_opts(seed: u64) -> NetworkOptions {
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let timeout = rng.gen_range(100..2000);
let max_delay = rng.gen_range(1..2*timeout);
let min_delay = rng.gen_range(1..=max_delay);
let max_fail_prob = rng.gen_range(0.0..0.9);
let connect_fail_prob = rng.gen_range(0.0..max_fail_prob);
let send_fail_prob = rng.gen_range(0.0..connect_fail_prob);
NetworkOptions {
keepalive_timeout: Some(timeout),
connect_delay: Delay {
min: min_delay,
max: max_delay,
fail_prob: connect_fail_prob,
},
send_delay: Delay {
min: min_delay,
max: max_delay,
fail_prob: send_fail_prob,
},
}
}
#[derive(Debug,Clone,PartialEq,Eq)]
enum NodeKind {
Unknown,
Safekeeper,
WalProposer,
}
impl Default for NodeKind {
fn default() -> Self {
Self::Unknown
}
}
#[derive(Clone, Debug, Default)]
struct NodeInfo {
kind: NodeKind,
// walproposer
is_sync: bool,
term: u64,
epoch_lsn: u64,
// safekeeper
commit_lsn: u64,
flush_lsn: u64,
}
impl NodeInfo {
fn init_kind(&mut self, kind: NodeKind) {
if self.kind == NodeKind::Unknown {
self.kind = kind;
} else {
assert!(self.kind == kind);
}
}
fn started(&mut self, data: &str) {
let mut parts = data.split(';');
assert!(parts.next().unwrap() == "started");
match parts.next().unwrap() {
"safekeeper" => {
self.init_kind(NodeKind::Safekeeper);
}
"walproposer" => {
self.init_kind(NodeKind::WalProposer);
let is_sync: u8 = parts.next().unwrap().parse().unwrap();
self.is_sync = is_sync != 0;
}
_ => unreachable!(),
}
}
}
#[derive(Debug,Default)]
struct GlobalState {
nodes: Vec<NodeInfo>,
commit_lsn: u64,
write_lsn: u64,
max_write_lsn: u64,
written_wal: u64,
written_records: u64,
}
impl GlobalState {
fn new() -> Self {
Default::default()
}
fn get(&mut self, id: u32) -> &mut NodeInfo {
let id = id as usize;
if id >= self.nodes.len() {
self.nodes.resize(id + 1, NodeInfo::default());
}
&mut self.nodes[id]
}
}
pub fn validate_events(events: Vec<SEvent>) {
const INITDB_LSN: u64 = 21623024;
let hook = std::panic::take_hook();
scopeguard::defer_on_success! {
std::panic::set_hook(hook);
};
let mut state = GlobalState::new();
state.max_write_lsn = INITDB_LSN;
for event in events {
debug!("{:?}", event);
let node = state.get(event.node);
if event.data.starts_with("started;") {
node.started(&event.data);
continue;
}
assert!(node.kind != NodeKind::Unknown);
// drop reference to unlock state
let mut node = node.clone();
let mut parts = event.data.split(';');
match node.kind {
NodeKind::Safekeeper => {
match parts.next().unwrap() {
"tli_loaded" => {
let flush_lsn: u64 = parts.next().unwrap().parse().unwrap();
let commit_lsn: u64 = parts.next().unwrap().parse().unwrap();
node.flush_lsn = flush_lsn;
node.commit_lsn = commit_lsn;
}
_ => unreachable!(),
}
}
NodeKind::WalProposer => {
match parts.next().unwrap() {
"prop_elected" => {
let prop_lsn: u64 = parts.next().unwrap().parse().unwrap();
let prop_term: u64 = parts.next().unwrap().parse().unwrap();
let prev_lsn: u64 = parts.next().unwrap().parse().unwrap();
let prev_term: u64 = parts.next().unwrap().parse().unwrap();
assert!(prop_lsn >= prev_lsn);
assert!(prop_term >= prev_term);
assert!(prop_lsn >= state.commit_lsn);
if prop_lsn > state.write_lsn {
assert!(prop_lsn <= state.max_write_lsn);
debug!("moving write_lsn up from {} to {}", state.write_lsn, prop_lsn);
state.write_lsn = prop_lsn;
}
if prop_lsn < state.write_lsn {
debug!("moving write_lsn down from {} to {}", state.write_lsn, prop_lsn);
state.write_lsn = prop_lsn;
}
node.epoch_lsn = prop_lsn;
node.term = prop_term;
}
"write_wal" => {
assert!(!node.is_sync);
let start_lsn: u64 = parts.next().unwrap().parse().unwrap();
let end_lsn: u64 = parts.next().unwrap().parse().unwrap();
let cnt: u64 = parts.next().unwrap().parse().unwrap();
let size = end_lsn - start_lsn;
state.written_wal += size;
state.written_records += cnt;
// TODO: If we allow writing WAL before winning the election
assert!(start_lsn >= state.commit_lsn);
assert!(end_lsn >= start_lsn);
assert!(start_lsn == state.write_lsn);
state.write_lsn = end_lsn;
if end_lsn > state.max_write_lsn {
state.max_write_lsn = end_lsn;
}
}
"commit_lsn" => {
let lsn: u64 = parts.next().unwrap().parse().unwrap();
assert!(lsn >= state.commit_lsn);
state.commit_lsn = lsn;
}
_ => unreachable!(),
}
}
_ => unreachable!(),
}
// update the node in the state struct
*state.get(event.node) = node;
}
}

View File

@@ -0,0 +1,265 @@
use std::{ffi::CString, path::Path, str::FromStr, sync::Arc};
use rand::Rng;
use safekeeper::simlib::{
network::{Delay, NetworkOptions},
proto::AnyMessage,
world::World,
world::{Node, NodeEvent},
};
use tracing::{info, warn};
use utils::{id::TenantTimelineId, lsn::Lsn};
use crate::{
bindings::{
neon_tenant_walproposer, neon_timeline_walproposer, sim_redo_start_lsn, syncSafekeepers,
wal_acceptor_connection_timeout, wal_acceptor_reconnect_timeout, wal_acceptors_list,
MyInsertRecord, WalProposerCleanup, WalProposerRust,
},
c_context,
simtest::{
log::{init_logger, SimClock},
safekeeper::run_server,
util::{generate_schedule, TestConfig, generate_network_opts, validate_events},
}, enable_debug,
};
use super::{
disk::Disk,
util::{Schedule, TestAction},
};
#[test]
fn sync_empty_safekeepers() {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
let test = config.start(1337);
let lsn = test.sync_safekeepers().unwrap();
assert_eq!(lsn, Lsn(0));
info!("Sucessfully synced empty safekeepers at 0/0");
let lsn = test.sync_safekeepers().unwrap();
assert_eq!(lsn, Lsn(0));
info!("Sucessfully synced (again) empty safekeepers at 0/0");
}
#[test]
fn run_walproposer_generate_wal() {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
// config.network.timeout = Some(250);
let test = config.start(1337);
let lsn = test.sync_safekeepers().unwrap();
assert_eq!(lsn, Lsn(0));
info!("Sucessfully synced empty safekeepers at 0/0");
let mut wp = test.launch_walproposer(lsn);
test.poll_for_duration(30);
for i in 0..100 {
wp.write_tx(1);
test.poll_for_duration(5);
}
}
#[test]
fn crash_safekeeper() {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
// config.network.timeout = Some(250);
let test = config.start(1337);
let lsn = test.sync_safekeepers().unwrap();
assert_eq!(lsn, Lsn(0));
info!("Sucessfully synced empty safekeepers at 0/0");
let mut wp = test.launch_walproposer(lsn);
test.poll_for_duration(30);
wp.write_tx(3);
test.servers[0].restart();
test.poll_for_duration(100);
test.poll_for_duration(1000);
}
#[test]
fn test_simple_restart() {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
// config.network.timeout = Some(250);
let test = config.start(1337);
let lsn = test.sync_safekeepers().unwrap();
assert_eq!(lsn, Lsn(0));
info!("Sucessfully synced empty safekeepers at 0/0");
let mut wp = test.launch_walproposer(lsn);
test.poll_for_duration(30);
wp.write_tx(3);
test.poll_for_duration(100);
wp.stop();
drop(wp);
let lsn = test.sync_safekeepers().unwrap();
info!("Sucessfully synced safekeepers at {}", lsn);
}
#[test]
fn test_simple_schedule() -> anyhow::Result<()> {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
config.network.keepalive_timeout = Some(100);
let test = config.start(1337);
let schedule: Schedule = vec![
(0, TestAction::RestartWalProposer),
(50, TestAction::WriteTx(5)),
(100, TestAction::RestartSafekeeper(0)),
(100, TestAction::WriteTx(5)),
(110, TestAction::RestartSafekeeper(1)),
(110, TestAction::WriteTx(5)),
(120, TestAction::RestartSafekeeper(2)),
(120, TestAction::WriteTx(5)),
(201, TestAction::RestartWalProposer),
(251, TestAction::RestartSafekeeper(0)),
(251, TestAction::RestartSafekeeper(1)),
(251, TestAction::RestartSafekeeper(2)),
(251, TestAction::WriteTx(5)),
(255, TestAction::WriteTx(5)),
(1000, TestAction::WriteTx(5)),
];
test.run_schedule(&schedule)?;
info!("Test finished, stopping all threads");
test.world.deallocate();
Ok(())
}
#[test]
fn test_many_tx() -> anyhow::Result<()> {
enable_debug();
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
let test = config.start(1337);
let mut schedule: Schedule = vec![];
for i in 0..100 {
schedule.push((i * 10, TestAction::WriteTx(10)));
}
test.run_schedule(&schedule)?;
info!("Test finished, stopping all threads");
test.world.stop_all();
let events = test.world.take_events();
info!("Events: {:?}", events);
let last_commit_lsn = events
.iter()
.filter_map(|event| {
if event.data.starts_with("commit_lsn;") {
let lsn: u64 = event.data.split(';').nth(1).unwrap().parse().unwrap();
return Some(lsn);
}
None
})
.last()
.unwrap();
let initdb_lsn = 21623024;
let diff = last_commit_lsn - initdb_lsn;
info!("Last commit lsn: {}, diff: {}", last_commit_lsn, diff);
assert!(diff > 1000 * 8);
Ok(())
}
#[test]
fn test_random_schedules() -> anyhow::Result<()> {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
config.network.keepalive_timeout = Some(100);
for i in 0..30000 {
let seed: u64 = rand::thread_rng().gen();
config.network = generate_network_opts(seed);
let test = config.start(seed);
warn!("Running test with seed {}", seed);
let schedule = generate_schedule(seed);
test.run_schedule(&schedule).unwrap();
validate_events(test.world.take_events());
test.world.deallocate();
}
Ok(())
}
#[test]
fn test_one_schedule() -> anyhow::Result<()> {
enable_debug();
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
config.network.keepalive_timeout = Some(100);
// let seed = 6762900106769428342;
// let test = config.start(seed);
// warn!("Running test with seed {}", seed);
// let schedule = generate_schedule(seed);
// info!("schedule: {:?}", schedule);
// test.run_schedule(&schedule)?;
// test.world.deallocate();
let seed = 3649773280641776194;
config.network = generate_network_opts(seed);
info!("network: {:?}", config.network);
let test = config.start(seed);
warn!("Running test with seed {}", seed);
let schedule = generate_schedule(seed);
info!("schedule: {:?}", schedule);
test.run_schedule(&schedule).unwrap();
validate_events(test.world.take_events());
test.world.deallocate();
Ok(())
}
#[test]
fn test_res_dealloc() -> anyhow::Result<()> {
// enable_debug();
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
// print pid
let pid = unsafe { libc::getpid() };
info!("pid: {}", pid);
let seed = 123456;
config.network = generate_network_opts(seed);
let test = config.start(seed);
warn!("Running test with seed {}", seed);
let schedule = generate_schedule(seed);
info!("schedule: {:?}", schedule);
test.run_schedule(&schedule).unwrap();
test.world.stop_all();
let world = test.world.clone();
drop(test);
info!("world strong count: {}", Arc::strong_count(&world));
world.deallocate();
info!("world strong count: {}", Arc::strong_count(&world));
Ok(())
}

View File

@@ -0,0 +1,31 @@
use tracing::info;
use crate::bindings::{TestFunc, MyContextInit};
#[test]
fn test_rust_c_calls() {
let res = std::thread::spawn(|| {
let res = unsafe {
MyContextInit();
TestFunc(1, 2)
};
res
}).join().unwrap();
info!("res: {}", res);
}
#[test]
fn test_sim_bindings() {
std::thread::spawn(|| {
unsafe {
MyContextInit();
TestFunc(1, 2)
}
}).join().unwrap();
std::thread::spawn(|| {
unsafe {
MyContextInit();
TestFunc(1, 2)
}
}).join().unwrap();
}

100
libs/walproposer/test.c Normal file
View File

@@ -0,0 +1,100 @@
#include "bindgen_deps.h"
#include "rust_bindings.h"
#include <stdio.h>
#include <pthread.h>
#include <stdlib.h>
#include "postgres.h"
#include "utils/memutils.h"
#include "utils/guc.h"
#include "miscadmin.h"
#include "common/pg_prng.h"
// From src/backend/main/main.c
const char *progname = "fakepostgres";
int TestFunc(int a, int b) {
printf("TestFunc: %d + %d = %d\n", a, b, a + b);
rust_function(0);
elog(LOG, "postgres elog test");
printf("After rust_function\n");
return a + b;
}
// This is a quick experiment with rewriting existing Rust code in C.
void RunClientC(uint32_t serverId) {
uint32_t clientId = sim_id();
elog(LOG, "started client");
int data_len = 5;
int delivered = 0;
int tcp = sim_open_tcp(serverId);
while (delivered < data_len) {
sim_msg_set_repl_cell(delivered+1, clientId, delivered);
sim_tcp_send(tcp);
Event event = sim_epoll_rcv(-1);
switch (event.tag)
{
case Closed:
elog(LOG, "connection closed");
tcp = sim_open_tcp(serverId);
break;
case Message:
Assert(event.any_message == Just32);
uint32_t val;
sim_msg_get_just_u32(&val);
if (val == delivered + 1) {
delivered += 1;
}
break;
default:
Assert(false);
}
}
}
bool debug_enabled = false;
bool initializedMemoryContext = false;
// pthread_mutex_init(&lock, NULL)?
pthread_mutex_t lock;
void MyContextInit() {
// initializes global variables, TODO how to make them thread-local?
pthread_mutex_lock(&lock);
if (!initializedMemoryContext) {
initializedMemoryContext = true;
MemoryContextInit();
pg_prng_seed(&pg_global_prng_state, 0);
setenv("PGDATA", "/home/admin/simulator/libs/walproposer/pgdata", 1);
/*
* Set default values for command-line options.
*/
InitializeGUCOptions();
/* Acquire configuration parameters */
if (!SelectConfigFiles(NULL, progname))
exit(1);
if (debug_enabled) {
log_min_messages = LOG;
} else {
log_min_messages = FATAL;
}
Log_line_prefix = "[%p] ";
InitializeMaxBackends();
ChangeToDataDir();
CreateSharedMemoryAndSemaphores();
SetInstallXLogFileSegmentActive();
// CreateAuxProcessResourceOwner();
// StartupXLOG();
}
pthread_mutex_unlock(&lock);
}

View File

@@ -37,7 +37,6 @@ num-traits.workspace = true
once_cell.workspace = true
pin-project-lite.workspace = true
postgres.workspace = true
postgres_backend.workspace = true
postgres-protocol.workspace = true
postgres-types.workspace = true
rand.workspace = true

View File

@@ -1,541 +0,0 @@
use anyhow::Result;
use pageserver::repository::Key;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::io::{self, BufRead};
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
fmt::Write,
ops::Range,
};
use svg_fmt::{rgb, BeginSvg, EndSvg, Fill, Stroke, Style};
use utils::{lsn::Lsn, project_git_version};
project_git_version!(GIT_VERSION);
// Map values to their compressed coordinate - the index the value
// would have in a sorted and deduplicated list of all values.
struct CoordinateMap<T: Ord + Copy> {
map: BTreeMap<T, usize>,
stretch: f32
}
impl<T: Ord + Copy> CoordinateMap<T> {
fn new(coords: Vec<T>, stretch: f32) -> Self {
let set: BTreeSet<T> = coords.into_iter().collect();
let mut map: BTreeMap<T, usize> = BTreeMap::new();
for (i, e) in set.iter().enumerate() {
map.insert(*e, i);
}
Self { map, stretch }
}
fn map(&self, val: T) -> f32 {
*self.map.get(&val).unwrap() as f32 * self.stretch
}
fn max(&self) -> f32 {
self.map.len() as f32 * self.stretch
}
}
fn parse_filename(name: &str) -> (Range<Key>, Range<Lsn>) {
let split: Vec<&str> = name.split("__").collect();
let keys: Vec<&str> = split[0].split('-').collect();
let mut lsns: Vec<&str> = split[1].split('-').collect();
if lsns.len() == 1 {
lsns.push(lsns[0]);
}
let keys = Key::from_hex(keys[0]).unwrap()..Key::from_hex(keys[1]).unwrap();
let lsns = Lsn::from_hex(lsns[0]).unwrap()..Lsn::from_hex(lsns[1]).unwrap();
(keys, lsns)
}
#[derive(Serialize, Deserialize, PartialEq)]
enum LayerTraceOp {
#[serde(rename = "evict")]
Evict,
#[serde(rename = "flush")]
Flush,
#[serde(rename = "compact_create")]
CompactCreate,
#[serde(rename = "compact_delete")]
CompactDelete,
#[serde(rename = "image_create")]
ImageCreate,
#[serde(rename = "gc_delete")]
GcDelete,
#[serde(rename = "gc_start")]
GcStart,
}
impl std::fmt::Display for LayerTraceOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
let op_str = match self {
LayerTraceOp::Evict => "evict",
LayerTraceOp::Flush => "flush",
LayerTraceOp::CompactCreate => "compact_create",
LayerTraceOp::CompactDelete => "compact_delete",
LayerTraceOp::ImageCreate => "image_create",
LayerTraceOp::GcDelete => "gc_delete",
LayerTraceOp::GcStart => "gc_start",
};
f.write_str(op_str)
}
}
#[serde_with::serde_as]
#[derive(Serialize, Deserialize)]
struct LayerTraceLine {
time: u64,
op: LayerTraceOp,
#[serde(default)]
filename: String,
#[serde_as(as = "Option<serde_with::DisplayFromStr>")]
cutoff: Option<Lsn>,
}
struct LayerTraceFile {
filename: String,
key_range: Range<Key>,
lsn_range: Range<Lsn>,
}
impl LayerTraceFile {
fn is_image(&self) -> bool {
self.lsn_range.start == self.lsn_range.end
}
}
struct LayerTraceEvent {
time_rel: u64,
op: LayerTraceOp,
filename: String,
}
struct GcEvent {
time_rel: u64,
cutoff: Lsn,
}
fn main() -> Result<()> {
// Parse trace lines from stdin
let stdin = io::stdin();
let mut files: HashMap<String, LayerTraceFile> = HashMap::new();
let mut layer_events: Vec<LayerTraceEvent> = Vec::new();
let mut gc_events: Vec<GcEvent> = Vec::new();
let mut first_time: Option<u64> = None;
for line in stdin.lock().lines() {
let line = line.unwrap();
let parsed_line: LayerTraceLine = serde_json::from_str(&line)?;
let time_rel = if let Some(first_time) = first_time {
parsed_line.time - first_time
} else {
first_time = Some(parsed_line.time);
0
};
if parsed_line.op == LayerTraceOp::GcStart {
gc_events.push(GcEvent {
time_rel,
cutoff: parsed_line.cutoff.unwrap(),
});
} else {
layer_events.push(LayerTraceEvent {
time_rel,
filename: parsed_line.filename.clone(),
op: parsed_line.op,
});
if !files.contains_key(&parsed_line.filename) {
let (key_range, lsn_range) = parse_filename(&parsed_line.filename);
files.insert(parsed_line.filename.clone(), LayerTraceFile {
filename: parsed_line.filename.clone(),
key_range,
lsn_range,
});
};
}
}
let mut last_time_rel = layer_events.last().unwrap().time_rel;
if let Some(last_gc) = gc_events.last() {
last_time_rel = std::cmp::min(last_gc.time_rel, last_time_rel);
}
// Collect all coordinates
let mut keys: Vec<Key> = vec![];
let mut lsns: Vec<Lsn> = vec![];
for f in files.values() {
keys.push(f.key_range.start);
keys.push(f.key_range.end);
lsns.push(f.lsn_range.start);
lsns.push(f.lsn_range.end);
}
for gc_event in &gc_events {
lsns.push(gc_event.cutoff);
}
// Analyze
let key_map = CoordinateMap::new(keys, 2.0);
// Stretch out vertically for better visibility
let lsn_map = CoordinateMap::new(lsns, 3.0);
// Initialize stats
let mut num_deltas = 0;
let mut num_images = 0;
let mut svg = String::new();
// Draw
writeln!(svg,
"{}",
BeginSvg {
w: key_map.max(),
h: lsn_map.max(),
}
)?;
let lsn_max = lsn_map.max();
// Sort the files by LSN, but so that image layers go after all delta layers
// The SVG is painted in the order the elements appear, and we want to draw
// image layers on top of the delta layers if they overlap
let mut files_sorted: Vec<LayerTraceFile> = files.into_values().collect();
files_sorted.sort_by(|a, b| {
if a.is_image() && !b.is_image() {
Ordering::Greater
} else if !a.is_image() && b.is_image() {
Ordering::Less
} else {
a.lsn_range.end.cmp(&b.lsn_range.end)
}
});
for f in files_sorted {
let key_start = key_map.map(f.key_range.start);
let key_end = key_map.map(f.key_range.end);
let key_diff = key_end - key_start;
if key_start >= key_end {
panic!("Invalid key range {}-{}", key_start, key_end);
}
let lsn_start = lsn_map.map(f.lsn_range.start);
let lsn_end = lsn_map.map(f.lsn_range.end);
// Fill in and thicken rectangle if it's an
// image layer so that we can see it.
let mut style = Style::default();
style.fill = Fill::Color(rgb(0x80, 0x80, 0x80));
style.stroke = Stroke::Color(rgb(0, 0, 0), 0.5);
let y_start = (lsn_max - lsn_start) as f32;
let y_end = (lsn_max - lsn_end) as f32;
let x_margin = 0.25;
let y_margin = 0.5;
match f.lsn_range.start.cmp(&f.lsn_range.end) {
Ordering::Less => {
num_deltas += 1;
write!(svg,
r#" <rect id="layer_{}" x="{}" y="{}" width="{}" height="{}" ry="{}" style="{}">"#,
f.filename,
key_start as f32 + x_margin,
y_end + y_margin,
key_diff as f32 - x_margin * 2.0,
y_start - y_end - y_margin * 2.0,
1.0, // border_radius,
style.to_string(),
)?;
write!(svg, "<title>{}<br>{} - {}</title>", f.filename, lsn_end, y_end)?;
writeln!(svg, "</rect>")?;
}
Ordering::Equal => {
num_images += 1;
//lsn_diff = 0.3;
//lsn_offset = -lsn_diff / 2.0;
//margin = 0.05;
style.fill = Fill::Color(rgb(0x80, 0, 0x80));
style.stroke = Stroke::Color(rgb(0x80, 0, 0x80), 3.0);
write!(svg,
r#" <line id="layer_{}" x1="{}" y1="{}" x2="{}" y2="{}" style="{}">"#,
f.filename,
key_start as f32 + x_margin,
y_end,
key_end as f32 - x_margin,
y_end,
style.to_string(),
)?;
write!(svg, "<title>{}<br>{} - {}</title>", f.filename, lsn_end, y_end)?;
writeln!(svg, "</line>")?;
}
Ordering::Greater => panic!("Invalid lsn range {}-{}", lsn_start, lsn_end),
}
}
for (idx, gc) in gc_events.iter().enumerate() {
let cutoff_lsn = lsn_map.map(gc.cutoff);
let mut style = Style::default();
style.fill = Fill::None;
style.stroke = Stroke::Color(rgb(0xff, 0, 0), 0.5);
let y = lsn_max - cutoff_lsn;
writeln!(svg,
r#" <line id="gc_{}" x1="{}" y1="{}" x2="{}" y2="{}" style="{}" />"#,
idx,
0,
y,
key_map.max(),
y,
style.to_string(),
)?;
}
writeln!(svg, "{}", EndSvg)?;
let mut layer_events_str = String::new();
let mut first = true;
for e in layer_events {
if !first {
writeln!(layer_events_str, ",")?;
}
write!(layer_events_str,
r#" {{"time_rel": {}, "filename": "{}", "op": "{}"}}"#,
e.time_rel, e.filename, e.op)?;
first = false;
}
writeln!(layer_events_str)?;
let mut gc_events_str = String::new();
let mut first = true;
for e in gc_events {
if !first {
writeln!(gc_events_str, ",")?;
}
write!(gc_events_str,
r#" {{"time_rel": {}, "cutoff_lsn": "{}"}}"#,
e.time_rel, e.cutoff)?;
first = false;
}
writeln!(gc_events_str)?;
println!(r#"<!DOCTYPE html>
<html>
<head>
<style>
/* Keep the slider pinned at top */
.topbar {{
display: block;
overflow: hidden;
background-color: lightgrey;
position: fixed;
top: 0;
width: 100%;
/* width: 500px; */
}}
.slidercontainer {{
float: left;
width: 50%;
margin-right: 200px;
}}
.slider {{
float: left;
width: 100%;
}}
.legend {{
width: 200px;
float: right;
}}
/* Main content */
.main {{
margin-top: 50px; /* Add a top margin to avoid content overlay */
}}
</style>
</head>
<body onload="init()">
<script type="text/javascript">
var layer_events = [{layer_events_str}]
var gc_events = [{gc_events_str}]
let ticker;
function init() {{
moveSlider({last_time_rel})
moveSlider(0)
moveSlider(last_slider_pos)
}}
function startAnimation() {{
ticker = setInterval(animateStep, 100);
}}
function stopAnimation() {{
clearInterval(ticker);
}}
function animateStep() {{
if (last_layer_event < layer_events.length - 1) {{
var slider = document.getElementById("time-slider");
let prevPos = slider.value
let nextEvent = last_layer_event
while (nextEvent < layer_events.length - 1) {{
if (layer_events[nextEvent].time_rel > prevPos) {{
break;
}}
nextEvent += 1;
}}
let nextPos = layer_events[nextEvent].time_rel
slider.value = nextPos
moveSlider(nextPos)
}}
}}
function redoLayerEvent(n, dir) {{
var layer = document.getElementById("layer_" + layer_events[n].filename);
switch (layer_events[n].op) {{
case "evict":
break;
case "flush":
layer.style.visibility = "visible";
break;
case "compact_create":
layer.style.visibility = "visible";
break;
case "image_create":
layer.style.visibility = "visible";
break;
case "compact_delete":
layer.style.visibility = "hidden";
break;
case "gc_delete":
layer.style.visibility = "hidden";
break;
case "gc_start":
layer.style.visibility = "hidden";
break;
}}
}}
function undoLayerEvent(n) {{
var layer = document.getElementById("layer_" + layer_events[n].filename);
switch (layer_events[n].op) {{
case "evict":
break;
case "flush":
layer.style.visibility = "hidden";
break;
case "compact_create":
layer.style.visibility = "hidden";
break;
case "image_create":
layer.style.visibility = "hidden";
break;
case "compact_delete":
layer.style.visibility = "visible";
break;
case "gc_delete":
layer.style.visibility = "visible";
break;
}}
}}
function redoGcEvent(n) {{
var prev_gc_bar = document.getElementById("gc_" + (n - 1));
var new_gc_bar = document.getElementById("gc_" + n);
prev_gc_bar.style.visibility = "hidden"
new_gc_bar.style.visibility = "visible"
}}
function undoGcEvent(n) {{
var prev_gc_bar = document.getElementById("gc_" + n);
var new_gc_bar = document.getElementById("gc_" + (n - 1));
prev_gc_bar.style.visibility = "hidden"
new_gc_bar.style.visibility = "visible"
}}
var last_slider_pos = 0
var last_layer_event = 0
var last_gc_event = 0
var moveSlider = function(new_pos) {{
if (new_pos > last_slider_pos) {{
while (last_layer_event < layer_events.length - 1) {{
if (layer_events[last_layer_event + 1].time_rel > new_pos) {{
break;
}}
last_layer_event += 1;
redoLayerEvent(last_layer_event)
}}
while (last_gc_event < gc_events.length - 1) {{
if (gc_events[last_gc_event + 1].time_rel > new_pos) {{
break;
}}
last_gc_event += 1;
redoGcEvent(last_gc_event)
}}
}}
if (new_pos < last_slider_pos) {{
while (last_layer_event > 0) {{
if (layer_events[last_layer_event - 1].time_rel < new_pos) {{
break;
}}
undoLayerEvent(last_layer_event)
last_layer_event -= 1;
}}
while (last_gc_event > 0) {{
if (gc_events[last_gc_event - 1].time_rel < new_pos) {{
break;
}}
undoGcEvent(last_gc_event)
last_gc_event -= 1;
}}
}}
last_slider_pos = new_pos;
document.getElementById("debug_pos").textContent=new_pos;
document.getElementById("debug_layer_event").textContent=last_layer_event + " " + layer_events[last_layer_event].time_rel + " " + layer_events[last_layer_event].op;
document.getElementById("debug_gc_event").textContent=last_gc_event + " " + gc_events[last_gc_event].time_rel;
}}
</script>
<div class="topbar">
<div class="slidercontainer">
<label for="time-slider">TIME</label>:
<input id="time-slider" class="slider" type="range" min="0" max="{last_time_rel}" value="0" oninput="moveSlider(this.value)"><br>
pos: <span id="debug_pos"></span><br>
event: <span id="debug_layer_event"></span><br>
gc: <span id="debug_gc_event"></span><br>
</div>
<button onclick="startAnimation()">Play</button>
<button onclick="stopAnimation()">Stop</button>
<svg class="legend">
<rect x=5 y=0 width=20 height=20 style="fill:rgb(128,128,128);stroke:rgb(0,0,0);stroke-width:0.5;fill-opacity:1;stroke-opacity:1;"/>
<line x1=5 y1=30 x2=25 y2=30 style="fill:rgb(128,0,128);stroke:rgb(128,0,128);stroke-width:3;fill-opacity:1;stroke-opacity:1;"/>
<line x1=0 y1=40 x2=30 y2=40 style="fill:none;stroke:rgb(255,0,0);stroke-width:0.5;fill-opacity:1;stroke-opacity:1;"/>
</svg>
</div>
<div class="main">
{svg}
</div>
</body>
</html>
"#);
eprintln!("num_images: {}", num_images);
eprintln!("num_deltas: {}", num_deltas);
Ok(())
}

View File

@@ -23,10 +23,11 @@ use pageserver::{
tenant::mgr,
virtual_file,
};
use postgres_backend::AuthType;
use utils::{
auth::JwtAuth,
logging, project_git_version,
logging,
postgres_backend::AuthType,
project_git_version,
sentry_init::init_sentry,
signals::{self, Signal},
tcp_listener,
@@ -270,43 +271,43 @@ fn start_pageserver(
WALRECEIVER_RUNTIME.block_on(pageserver::broker_client::init_broker_client(conf))?;
// Initialize authentication for incoming connections
let http_auth;
let pg_auth;
if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT {
// unwrap is ok because check is performed when creating config, so path is set and file exists
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
info!(
"Loading public key for verifying JWT tokens from {:#?}",
key_path
);
let auth: Arc<JwtAuth> = Arc::new(JwtAuth::from_key_path(key_path)?);
let auth = match &conf.auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => {
// unwrap is ok because check is performed when creating config, so path is set and file exists
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
Some(JwtAuth::from_key_path(key_path)?.into())
}
};
info!("Using auth: {:#?}", conf.auth_type);
http_auth = match &conf.http_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth.clone()),
};
pg_auth = match &conf.pg_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth),
};
} else {
http_auth = None;
pg_auth = None;
}
info!("Using auth for http API: {:#?}", conf.http_auth_type);
info!("Using auth for pg connections: {:#?}", conf.pg_auth_type);
match var("NEON_AUTH_TOKEN") {
Ok(v) => {
// TODO: remove ZENITH_AUTH_TOKEN once it's not used anywhere in development/staging/prod configuration.
match (var("ZENITH_AUTH_TOKEN"), var("NEON_AUTH_TOKEN")) {
(old, Ok(v)) => {
info!("Loaded JWT token for authentication with Safekeeper");
if let Ok(v_old) = old {
warn!(
"JWT token for Safekeeper is specified twice, ZENITH_AUTH_TOKEN is deprecated"
);
if v_old != v {
warn!("JWT token for Safekeeper has two different values, choosing NEON_AUTH_TOKEN");
}
}
pageserver::config::SAFEKEEPER_AUTH_TOKEN
.set(Arc::new(v))
.map_err(|_| anyhow!("Could not initialize SAFEKEEPER_AUTH_TOKEN"))?;
}
Err(VarError::NotPresent) => {
(Ok(v), _) => {
info!("Loaded JWT token for authentication with Safekeeper");
warn!("Please update pageserver configuration: the JWT token should be NEON_AUTH_TOKEN, not ZENITH_AUTH_TOKEN");
pageserver::config::SAFEKEEPER_AUTH_TOKEN
.set(Arc::new(v))
.map_err(|_| anyhow!("Could not initialize SAFEKEEPER_AUTH_TOKEN"))?;
}
(_, Err(VarError::NotPresent)) => {
info!("No JWT token for authentication with Safekeeper detected");
}
Err(e) => {
(_, Err(e)) => {
return Err(e).with_context(|| {
"Failed to either load to detect non-present NEON_AUTH_TOKEN environment variable"
})
@@ -324,7 +325,7 @@ fn start_pageserver(
{
let _rt_guard = MGMT_REQUEST_RUNTIME.enter();
let router = http::make_router(conf, launch_ts, http_auth, remote_storage)?
let router = http::make_router(conf, launch_ts, auth.clone(), remote_storage)?
.build()
.map_err(|err| anyhow!(err))?;
let service = utils::http::RouterService::new(router).unwrap();
@@ -398,9 +399,9 @@ fn start_pageserver(
async move {
page_service::libpq_listener_main(
conf,
pg_auth,
auth,
pageserver_listener,
conf.pg_auth_type,
conf.auth_type,
libpq_ctx,
)
.await

View File

@@ -21,10 +21,10 @@ use std::time::Duration;
use toml_edit;
use toml_edit::{Document, Item};
use postgres_backend::AuthType;
use utils::{
id::{NodeId, TenantId, TimelineId},
logging::LogFormat,
postgres_backend::AuthType,
};
use crate::tenant::config::TenantConf;
@@ -61,7 +61,6 @@ pub mod defaults {
pub const DEFAULT_CACHED_METRIC_COLLECTION_INTERVAL: &str = "1 hour";
pub const DEFAULT_METRIC_COLLECTION_ENDPOINT: Option<reqwest::Url> = None;
pub const DEFAULT_SYNTHETIC_SIZE_CALCULATION_INTERVAL: &str = "10 min";
pub const DEFAULT_EVICTIONS_LOW_RESIDENCE_DURATION_METRIC_THRESHOLD: &str = "24 hour";
///
/// Default built-in configuration file.
@@ -90,8 +89,6 @@ pub mod defaults {
#cached_metric_collection_interval = '{DEFAULT_CACHED_METRIC_COLLECTION_INTERVAL}'
#synthetic_size_calculation_interval = '{DEFAULT_SYNTHETIC_SIZE_CALCULATION_INTERVAL}'
#evictions_low_residence_duration_metric_threshold = '{DEFAULT_EVICTIONS_LOW_RESIDENCE_DURATION_METRIC_THRESHOLD}'
# [tenant_config]
#checkpoint_distance = {DEFAULT_CHECKPOINT_DISTANCE} # in bytes
#checkpoint_timeout = {DEFAULT_CHECKPOINT_TIMEOUT}
@@ -121,9 +118,6 @@ pub struct PageServerConf {
/// Example (default): 127.0.0.1:9898
pub listen_http_addr: String,
/// Current availability zone. Used for traffic metrics.
pub availability_zone: Option<String>,
// Timeout when waiting for WAL receiver to catch up to an LSN given in a GetPage@LSN call.
pub wait_lsn_timeout: Duration,
// How long to wait for WAL redo to complete.
@@ -144,15 +138,9 @@ pub struct PageServerConf {
pub pg_distrib_dir: PathBuf,
// Authentication
/// authentication method for the HTTP mgmt API
pub http_auth_type: AuthType,
/// authentication method for libpq connections from compute
pub pg_auth_type: AuthType,
/// Path to a file containing public key for verifying JWT tokens.
/// Used for both mgmt and compute auth, if enabled.
pub auth_validation_public_key_path: Option<PathBuf>,
pub auth_type: AuthType,
pub auth_validation_public_key_path: Option<PathBuf>,
pub remote_storage_config: Option<RemoteStorageConfig>,
pub default_tenant_conf: TenantConf,
@@ -173,9 +161,6 @@ pub struct PageServerConf {
pub metric_collection_endpoint: Option<Url>,
pub synthetic_size_calculation_interval: Duration,
// See the corresponding metric's help string.
pub evictions_low_residence_duration_metric_threshold: Duration,
pub test_remote_failures: u64,
pub ondemand_download_behavior_treat_error_as_warn: bool,
@@ -211,8 +196,6 @@ struct PageServerConfigBuilder {
listen_http_addr: BuilderValue<String>,
availability_zone: BuilderValue<Option<String>>,
wait_lsn_timeout: BuilderValue<Duration>,
wal_redo_timeout: BuilderValue<Duration>,
@@ -225,8 +208,7 @@ struct PageServerConfigBuilder {
pg_distrib_dir: BuilderValue<PathBuf>,
http_auth_type: BuilderValue<AuthType>,
pg_auth_type: BuilderValue<AuthType>,
auth_type: BuilderValue<AuthType>,
//
auth_validation_public_key_path: BuilderValue<Option<PathBuf>>,
@@ -246,8 +228,6 @@ struct PageServerConfigBuilder {
metric_collection_endpoint: BuilderValue<Option<Url>>,
synthetic_size_calculation_interval: BuilderValue<Duration>,
evictions_low_residence_duration_metric_threshold: BuilderValue<Duration>,
test_remote_failures: BuilderValue<u64>,
ondemand_download_behavior_treat_error_as_warn: BuilderValue<bool>,
@@ -260,7 +240,6 @@ impl Default for PageServerConfigBuilder {
Self {
listen_pg_addr: Set(DEFAULT_PG_LISTEN_ADDR.to_string()),
listen_http_addr: Set(DEFAULT_HTTP_LISTEN_ADDR.to_string()),
availability_zone: Set(None),
wait_lsn_timeout: Set(humantime::parse_duration(DEFAULT_WAIT_LSN_TIMEOUT)
.expect("cannot parse default wait lsn timeout")),
wal_redo_timeout: Set(humantime::parse_duration(DEFAULT_WAL_REDO_TIMEOUT)
@@ -272,8 +251,7 @@ impl Default for PageServerConfigBuilder {
pg_distrib_dir: Set(env::current_dir()
.expect("cannot access current directory")
.join("pg_install")),
http_auth_type: Set(AuthType::Trust),
pg_auth_type: Set(AuthType::Trust),
auth_type: Set(AuthType::Trust),
auth_validation_public_key_path: Set(None),
remote_storage_config: Set(None),
id: NotSet,
@@ -301,11 +279,6 @@ impl Default for PageServerConfigBuilder {
.expect("cannot parse default synthetic size calculation interval")),
metric_collection_endpoint: Set(DEFAULT_METRIC_COLLECTION_ENDPOINT),
evictions_low_residence_duration_metric_threshold: Set(humantime::parse_duration(
DEFAULT_EVICTIONS_LOW_RESIDENCE_DURATION_METRIC_THRESHOLD,
)
.expect("cannot parse DEFAULT_EVICTIONS_LOW_RESIDENCE_DURATION_METRIC_THRESHOLD")),
test_remote_failures: Set(0),
ondemand_download_behavior_treat_error_as_warn: Set(false),
@@ -322,10 +295,6 @@ impl PageServerConfigBuilder {
self.listen_http_addr = BuilderValue::Set(listen_http_addr)
}
pub fn availability_zone(&mut self, availability_zone: Option<String>) {
self.availability_zone = BuilderValue::Set(availability_zone)
}
pub fn wait_lsn_timeout(&mut self, wait_lsn_timeout: Duration) {
self.wait_lsn_timeout = BuilderValue::Set(wait_lsn_timeout)
}
@@ -354,12 +323,8 @@ impl PageServerConfigBuilder {
self.pg_distrib_dir = BuilderValue::Set(pg_distrib_dir)
}
pub fn http_auth_type(&mut self, auth_type: AuthType) {
self.http_auth_type = BuilderValue::Set(auth_type)
}
pub fn pg_auth_type(&mut self, auth_type: AuthType) {
self.pg_auth_type = BuilderValue::Set(auth_type)
pub fn auth_type(&mut self, auth_type: AuthType) {
self.auth_type = BuilderValue::Set(auth_type)
}
pub fn auth_validation_public_key_path(
@@ -421,10 +386,6 @@ impl PageServerConfigBuilder {
self.test_remote_failures = BuilderValue::Set(fail_first);
}
pub fn evictions_low_residence_duration_metric_threshold(&mut self, value: Duration) {
self.evictions_low_residence_duration_metric_threshold = BuilderValue::Set(value);
}
pub fn ondemand_download_behavior_treat_error_as_warn(
&mut self,
ondemand_download_behavior_treat_error_as_warn: bool,
@@ -441,9 +402,6 @@ impl PageServerConfigBuilder {
listen_http_addr: self
.listen_http_addr
.ok_or(anyhow!("missing listen_http_addr"))?,
availability_zone: self
.availability_zone
.ok_or(anyhow!("missing availability_zone"))?,
wait_lsn_timeout: self
.wait_lsn_timeout
.ok_or(anyhow!("missing wait_lsn_timeout"))?,
@@ -461,10 +419,7 @@ impl PageServerConfigBuilder {
pg_distrib_dir: self
.pg_distrib_dir
.ok_or(anyhow!("missing pg_distrib_dir"))?,
http_auth_type: self
.http_auth_type
.ok_or(anyhow!("missing http_auth_type"))?,
pg_auth_type: self.pg_auth_type.ok_or(anyhow!("missing pg_auth_type"))?,
auth_type: self.auth_type.ok_or(anyhow!("missing auth_type"))?,
auth_validation_public_key_path: self
.auth_validation_public_key_path
.ok_or(anyhow!("missing auth_validation_public_key_path"))?,
@@ -498,11 +453,6 @@ impl PageServerConfigBuilder {
synthetic_size_calculation_interval: self
.synthetic_size_calculation_interval
.ok_or(anyhow!("missing synthetic_size_calculation_interval"))?,
evictions_low_residence_duration_metric_threshold: self
.evictions_low_residence_duration_metric_threshold
.ok_or(anyhow!(
"missing evictions_low_residence_duration_metric_threshold"
))?,
test_remote_failures: self
.test_remote_failures
.ok_or(anyhow!("missing test_remote_failuers"))?,
@@ -649,7 +599,6 @@ impl PageServerConf {
match key {
"listen_pg_addr" => builder.listen_pg_addr(parse_toml_string(key, item)?),
"listen_http_addr" => builder.listen_http_addr(parse_toml_string(key, item)?),
"availability_zone" => builder.availability_zone(Some(parse_toml_string(key, item)?)),
"wait_lsn_timeout" => builder.wait_lsn_timeout(parse_toml_duration(key, item)?),
"wal_redo_timeout" => builder.wal_redo_timeout(parse_toml_duration(key, item)?),
"initial_superuser_name" => builder.superuser(parse_toml_string(key, item)?),
@@ -663,8 +612,7 @@ impl PageServerConf {
"auth_validation_public_key_path" => builder.auth_validation_public_key_path(Some(
PathBuf::from(parse_toml_string(key, item)?),
)),
"http_auth_type" => builder.http_auth_type(parse_toml_from_str(key, item)?),
"pg_auth_type" => builder.pg_auth_type(parse_toml_from_str(key, item)?),
"auth_type" => builder.auth_type(parse_toml_from_str(key, item)?),
"remote_storage" => {
builder.remote_storage_config(RemoteStorageConfig::from_toml(item)?)
}
@@ -692,7 +640,6 @@ impl PageServerConf {
"synthetic_size_calculation_interval" =>
builder.synthetic_size_calculation_interval(parse_toml_duration(key, item)?),
"test_remote_failures" => builder.test_remote_failures(parse_toml_u64(key, item)?),
"evictions_low_residence_duration_metric_threshold" => builder.evictions_low_residence_duration_metric_threshold(parse_toml_duration(key, item)?),
"ondemand_download_behavior_treat_error_as_warn" => builder.ondemand_download_behavior_treat_error_as_warn(parse_toml_bool(key, item)?),
_ => bail!("unrecognized pageserver option '{key}'"),
}
@@ -700,7 +647,7 @@ impl PageServerConf {
let mut conf = builder.build().context("invalid config")?;
if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT {
if conf.auth_type == AuthType::NeonJWT {
let auth_validation_public_key_path = conf
.auth_validation_public_key_path
.get_or_insert_with(|| workdir.join("auth_public_key.pem"));
@@ -751,12 +698,6 @@ impl PageServerConf {
Some(parse_toml_u64("compaction_threshold", compaction_threshold)?.try_into()?);
}
if let Some(image_creation_threshold) = item.get("image_creation_threshold") {
t_conf.image_creation_threshold = Some(
parse_toml_u64("image_creation_threshold", image_creation_threshold)?.try_into()?,
);
}
if let Some(gc_horizon) = item.get("gc_horizon") {
t_conf.gc_horizon = Some(parse_toml_u64("gc_horizon", gc_horizon)?);
}
@@ -816,12 +757,10 @@ impl PageServerConf {
max_file_descriptors: defaults::DEFAULT_MAX_FILE_DESCRIPTORS,
listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(),
listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(),
availability_zone: None,
superuser: "cloud_admin".to_string(),
workdir: repo_dir,
pg_distrib_dir,
http_auth_type: AuthType::Trust,
pg_auth_type: AuthType::Trust,
auth_type: AuthType::Trust,
auth_validation_public_key_path: None,
remote_storage_config: None,
default_tenant_conf: TenantConf::default(),
@@ -833,10 +772,6 @@ impl PageServerConf {
cached_metric_collection_interval: Duration::from_secs(60 * 60),
metric_collection_endpoint: defaults::DEFAULT_METRIC_COLLECTION_ENDPOINT,
synthetic_size_calculation_interval: Duration::from_secs(60),
evictions_low_residence_duration_metric_threshold: humantime::parse_duration(
defaults::DEFAULT_EVICTIONS_LOW_RESIDENCE_DURATION_METRIC_THRESHOLD,
)
.unwrap(),
test_remote_failures: 0,
ondemand_download_behavior_treat_error_as_warn: false,
}
@@ -978,9 +913,6 @@ metric_collection_interval = '222 s'
cached_metric_collection_interval = '22200 s'
metric_collection_endpoint = 'http://localhost:80/metrics'
synthetic_size_calculation_interval = '333 s'
evictions_low_residence_duration_metric_threshold = '444 s'
log_format = 'json'
"#;
@@ -1006,7 +938,6 @@ log_format = 'json'
id: NodeId(10),
listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(),
listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(),
availability_zone: None,
wait_lsn_timeout: humantime::parse_duration(defaults::DEFAULT_WAIT_LSN_TIMEOUT)?,
wal_redo_timeout: humantime::parse_duration(defaults::DEFAULT_WAL_REDO_TIMEOUT)?,
superuser: defaults::DEFAULT_SUPERUSER.to_string(),
@@ -1014,8 +945,7 @@ log_format = 'json'
max_file_descriptors: defaults::DEFAULT_MAX_FILE_DESCRIPTORS,
workdir,
pg_distrib_dir,
http_auth_type: AuthType::Trust,
pg_auth_type: AuthType::Trust,
auth_type: AuthType::Trust,
auth_validation_public_key_path: None,
remote_storage_config: None,
default_tenant_conf: TenantConf::default(),
@@ -1035,9 +965,6 @@ log_format = 'json'
synthetic_size_calculation_interval: humantime::parse_duration(
defaults::DEFAULT_SYNTHETIC_SIZE_CALCULATION_INTERVAL
)?,
evictions_low_residence_duration_metric_threshold: humantime::parse_duration(
defaults::DEFAULT_EVICTIONS_LOW_RESIDENCE_DURATION_METRIC_THRESHOLD
)?,
test_remote_failures: 0,
ondemand_download_behavior_treat_error_as_warn: false,
},
@@ -1068,7 +995,6 @@ log_format = 'json'
id: NodeId(10),
listen_pg_addr: "127.0.0.1:64000".to_string(),
listen_http_addr: "127.0.0.1:9898".to_string(),
availability_zone: None,
wait_lsn_timeout: Duration::from_secs(111),
wal_redo_timeout: Duration::from_secs(111),
superuser: "zzzz".to_string(),
@@ -1076,8 +1002,7 @@ log_format = 'json'
max_file_descriptors: 333,
workdir,
pg_distrib_dir,
http_auth_type: AuthType::Trust,
pg_auth_type: AuthType::Trust,
auth_type: AuthType::Trust,
auth_validation_public_key_path: None,
remote_storage_config: None,
default_tenant_conf: TenantConf::default(),
@@ -1089,7 +1014,6 @@ log_format = 'json'
cached_metric_collection_interval: Duration::from_secs(22200),
metric_collection_endpoint: Some(Url::parse("http://localhost:80/metrics")?),
synthetic_size_calculation_interval: Duration::from_secs(333),
evictions_low_residence_duration_metric_threshold: Duration::from_secs(444),
test_remote_failures: 0,
ondemand_download_behavior_treat_error_as_warn: false,
},

View File

@@ -245,53 +245,6 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/timeline/{timeline_id}/do_gc:
parameters:
- name: tenant_id
in: path
required: true
schema:
type: string
format: hex
- name: timeline_id
in: path
required: true
schema:
type: string
format: hex
put:
description: Garbage collect given timeline
responses:
"200":
description: OK
content:
application/json:
schema:
type: string
"400":
description: Error when no tenant id found in path, no timeline id or invalid timestamp
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/attach:
parameters:
- name: tenant_id
@@ -351,13 +304,6 @@ paths:
schema:
type: string
format: hex
- name: detach_ignored
in: query
required: false
schema:
type: boolean
description: |
When true, allow to detach a tenant which state is ignored.
post:
description: |
Remove tenant data (including all corresponding timelines) from pageserver's memory and file system.

View File

@@ -10,7 +10,6 @@ use remote_storage::GenericRemoteStorage;
use tenant_size_model::{SizeResult, StorageModel};
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::http::endpoint::RequestSpan;
use utils::http::request::{get_request_param, must_get_query_param, parse_query_param};
use super::models::{
@@ -21,7 +20,7 @@ use crate::context::{DownloadBehavior, RequestContext};
use crate::pgdatadir_mapping::LsnForTimestamp;
use crate::task_mgr::TaskKind;
use crate::tenant::config::TenantConfOpt;
use crate::tenant::mgr::{TenantMapInsertError, TenantStateError};
use crate::tenant::mgr::TenantMapInsertError;
use crate::tenant::size::ModelInputs;
use crate::tenant::storage_layer::LayerAccessStatsReset;
use crate::tenant::{PageReconstructError, Timeline};
@@ -82,52 +81,38 @@ fn get_config(request: &Request<Body>) -> &'static PageServerConf {
get_state(request).conf
}
/// Check that the requester is authorized to operate on given tenant
fn check_permission(request: &Request<Body>, tenant_id: Option<TenantId>) -> Result<(), ApiError> {
check_permission_with(request, |claims| {
crate::auth::check_permission(claims, tenant_id)
})
}
impl From<PageReconstructError> for ApiError {
fn from(pre: PageReconstructError) -> ApiError {
match pre {
PageReconstructError::Other(pre) => ApiError::InternalServerError(pre),
PageReconstructError::NeedsDownload(_, _) => {
// This shouldn't happen, because we use a RequestContext that requests to
// download any missing layer files on-demand.
ApiError::InternalServerError(anyhow::anyhow!("need to download remote layer file"))
}
PageReconstructError::Cancelled => {
ApiError::InternalServerError(anyhow::anyhow!("request was cancelled"))
}
PageReconstructError::WalRedo(pre) => {
ApiError::InternalServerError(anyhow::Error::new(pre))
}
fn apierror_from_prerror(err: PageReconstructError) -> ApiError {
match err {
PageReconstructError::Other(err) => ApiError::InternalServerError(err),
PageReconstructError::NeedsDownload(_, _) => {
// This shouldn't happen, because we use a RequestContext that requests to
// download any missing layer files on-demand.
ApiError::InternalServerError(anyhow::anyhow!("need to download remote layer file"))
}
PageReconstructError::Cancelled => {
ApiError::InternalServerError(anyhow::anyhow!("request was cancelled"))
}
PageReconstructError::WalRedo(err) => {
ApiError::InternalServerError(anyhow::Error::new(err))
}
}
}
impl From<TenantMapInsertError> for ApiError {
fn from(tmie: TenantMapInsertError) -> ApiError {
match tmie {
TenantMapInsertError::StillInitializing | TenantMapInsertError::ShuttingDown => {
ApiError::InternalServerError(anyhow::Error::new(tmie))
}
TenantMapInsertError::TenantAlreadyExists(id, state) => {
ApiError::Conflict(format!("tenant {id} already exists, state: {state:?}"))
}
TenantMapInsertError::Closure(e) => ApiError::InternalServerError(e),
fn apierror_from_tenant_map_insert_error(e: TenantMapInsertError) -> ApiError {
match e {
TenantMapInsertError::StillInitializing | TenantMapInsertError::ShuttingDown => {
ApiError::InternalServerError(anyhow::Error::new(e))
}
}
}
impl From<TenantStateError> for ApiError {
fn from(tse: TenantStateError) -> ApiError {
match tse {
TenantStateError::NotFound(tid) => ApiError::NotFound(anyhow!("tenant {}", tid)),
_ => ApiError::InternalServerError(anyhow::Error::new(tse)),
TenantMapInsertError::TenantAlreadyExists(id, state) => {
ApiError::Conflict(format!("tenant {id} already exists, state: {state:?}"))
}
TenantMapInsertError::Closure(e) => ApiError::InternalServerError(e),
}
}
@@ -185,7 +170,7 @@ fn build_timeline_info_common(
None
}
};
let current_physical_size = Some(timeline.layer_size_sum());
let current_physical_size = Some(timeline.layer_size_sum().approximate_is_ok());
let state = timeline.current_state();
let remote_consistent_lsn = timeline.get_remote_consistent_lsn().unwrap_or(Lsn(0));
@@ -231,7 +216,9 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Error);
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::NotFound)?;
match tenant.create_timeline(
new_timeline_id,
request_data.ancestor_timeline_id.map(TimelineId::from),
@@ -261,7 +248,9 @@ async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>,
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let response_data = async {
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::NotFound)?;
let timelines = tenant.list_timelines();
let mut response_data = Vec::with_capacity(timelines.len());
@@ -277,7 +266,7 @@ async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>,
response_data.push(timeline_info);
}
Ok::<Vec<TimelineInfo>, ApiError>(response_data)
Ok(response_data)
}
.instrument(info_span!("timeline_list", tenant = %tenant_id))
.await?;
@@ -296,7 +285,9 @@ async fn timeline_detail_handler(request: Request<Body>) -> Result<Response<Body
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline_info = async {
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::NotFound)?;
let timeline = tenant
.get_timeline(timeline_id, false)
@@ -332,7 +323,10 @@ async fn get_lsn_by_timestamp_handler(request: Request<Body>) -> Result<Response
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline = active_timeline_of_active_tenant(tenant_id, timeline_id).await?;
let result = timeline.find_lsn_for_timestamp(timestamp_pg, &ctx).await?;
let result = timeline
.find_lsn_for_timestamp(timestamp_pg, &ctx)
.await
.map_err(apierror_from_prerror)?;
let result = match result {
LsnForTimestamp::Present(lsn) => format!("{lsn}"),
@@ -357,7 +351,8 @@ async fn tenant_attach_handler(request: Request<Body>) -> Result<Response<Body>,
if let Some(remote_storage) = &state.remote_storage {
mgr::attach_tenant(state.conf, tenant_id, remote_storage.clone(), &ctx)
.instrument(info_span!("tenant_attach", tenant = %tenant_id))
.await?;
.await
.map_err(apierror_from_tenant_map_insert_error)?;
} else {
return Err(ApiError::BadRequest(anyhow!(
"attach_tenant is not possible because pageserver was configured without remote storage"
@@ -376,7 +371,11 @@ async fn timeline_delete_handler(request: Request<Body>) -> Result<Response<Body
mgr::delete_timeline(tenant_id, timeline_id, &ctx)
.instrument(info_span!("timeline_delete", tenant = %tenant_id, timeline = %timeline_id))
.await?;
.await
// FIXME: Errors from `delete_timeline` can occur for a number of reasons, incuding both
// user and internal errors. Replace this with better handling once the error type permits
// it.
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -384,13 +383,15 @@ async fn timeline_delete_handler(request: Request<Body>) -> Result<Response<Body
async fn tenant_detach_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&request, "tenant_id")?;
check_permission(&request, Some(tenant_id))?;
let detach_ignored: Option<bool> = parse_query_param(&request, "detach_ignored")?;
let state = get_state(&request);
let conf = state.conf;
mgr::detach_tenant(conf, tenant_id, detach_ignored.unwrap_or(false))
mgr::detach_tenant(conf, tenant_id)
.instrument(info_span!("tenant_detach", tenant = %tenant_id))
.await?;
.await
// FIXME: Errors from `detach_tenant` can be caused by both both user and internal errors.
// Replace this with better handling once the error type permits it.
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -404,7 +405,8 @@ async fn tenant_load_handler(request: Request<Body>) -> Result<Response<Body>, A
let state = get_state(&request);
mgr::load_tenant(state.conf, tenant_id, state.remote_storage.clone(), &ctx)
.instrument(info_span!("load", tenant = %tenant_id))
.await?;
.await
.map_err(apierror_from_tenant_map_insert_error)?;
json_response(StatusCode::ACCEPTED, ())
}
@@ -417,7 +419,10 @@ async fn tenant_ignore_handler(request: Request<Body>) -> Result<Response<Body>,
let conf = state.conf;
mgr::ignore_tenant(conf, tenant_id)
.instrument(info_span!("ignore_tenant", tenant = %tenant_id))
.await?;
.await
// FIXME: Errors from `ignore_tenant` can be caused by both both user and internal errors.
// Replace this with better handling once the error type permits it.
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -452,7 +457,7 @@ async fn tenant_status(request: Request<Body>) -> Result<Response<Body>, ApiErro
// Calculate total physical size of all timelines
let mut current_physical_size = 0;
for timeline in tenant.list_timelines().iter() {
current_physical_size += timeline.layer_size_sum();
current_physical_size += timeline.layer_size_sum().approximate_is_ok();
}
let state = tenant.current_state();
@@ -491,7 +496,9 @@ async fn tenant_size_handler(request: Request<Body>) -> Result<Response<Body>, A
let headers = request.headers();
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::InternalServerError)?;
// this can be long operation
let inputs = tenant
@@ -739,14 +746,6 @@ async fn tenant_create_handler(mut request: Request<Body>) -> Result<Response<Bo
);
}
if let Some(eviction_policy) = request_data.eviction_policy {
tenant_conf.eviction_policy = Some(
serde_json::from_value(eviction_policy)
.context("parse field `eviction_policy`")
.map_err(ApiError::BadRequest)?,
);
}
let target_tenant_id = request_data
.new_tenant_id
.map(TenantId::from)
@@ -762,7 +761,8 @@ async fn tenant_create_handler(mut request: Request<Body>) -> Result<Response<Bo
&ctx,
)
.instrument(info_span!("tenant_create", tenant = ?target_tenant_id))
.await?;
.await
.map_err(apierror_from_tenant_map_insert_error)?;
// We created the tenant. Existing API semantics are that the tenant
// is Active when this function returns.
@@ -786,7 +786,9 @@ async fn get_tenant_config_handler(request: Request<Body>) -> Result<Response<Bo
let tenant_id: TenantId = parse_request_param(&request, "tenant_id")?;
check_permission(&request, Some(tenant_id))?;
let tenant = mgr::get_tenant(tenant_id, false).await?;
let tenant = mgr::get_tenant(tenant_id, false)
.await
.map_err(ApiError::NotFound)?;
let response = HashMap::from([
(
@@ -881,7 +883,10 @@ async fn update_tenant_config_handler(
let state = get_state(&request);
mgr::set_new_tenant_config(state.conf, tenant_conf, tenant_id)
.instrument(info_span!("tenant_config", tenant = ?tenant_id))
.await?;
.await
// FIXME: `update_tenant_config` can fail because of both user and internal errors.
// Replace this `map_err` with better error handling once the type permits it
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -1018,7 +1023,9 @@ async fn active_timeline_of_active_tenant(
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<Arc<Timeline>, ApiError> {
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::NotFound)?;
tenant
.get_timeline(timeline_id, true)
.map_err(ApiError::NotFound)
@@ -1084,8 +1091,7 @@ pub fn make_router(
let handler = $handler;
#[cfg(not(feature = "testing"))]
let handler = cfg_disabled;
move |r| RequestSpan(handler).handle(r)
handler
}};
}
@@ -1093,55 +1099,35 @@ pub fn make_router(
.data(Arc::new(
State::new(conf, auth, remote_storage).context("Failed to initialize router state")?,
))
.get("/v1/status", |r| RequestSpan(status_handler).handle(r))
.get("/v1/status", status_handler)
.put(
"/v1/failpoints",
testing_api!("manage failpoints", failpoints_handler),
)
.get("/v1/tenant", |r| RequestSpan(tenant_list_handler).handle(r))
.post("/v1/tenant", |r| {
RequestSpan(tenant_create_handler).handle(r)
})
.get("/v1/tenant/:tenant_id", |r| {
RequestSpan(tenant_status).handle(r)
})
.get("/v1/tenant/:tenant_id/synthetic_size", |r| {
RequestSpan(tenant_size_handler).handle(r)
})
.put("/v1/tenant/config", |r| {
RequestSpan(update_tenant_config_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/config", |r| {
RequestSpan(get_tenant_config_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline", |r| {
RequestSpan(timeline_list_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/timeline", |r| {
RequestSpan(timeline_create_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/attach", |r| {
RequestSpan(tenant_attach_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/detach", |r| {
RequestSpan(tenant_detach_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/load", |r| {
RequestSpan(tenant_load_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/ignore", |r| {
RequestSpan(tenant_ignore_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
RequestSpan(timeline_detail_handler).handle(r)
})
.get("/v1/tenant", tenant_list_handler)
.post("/v1/tenant", tenant_create_handler)
.get("/v1/tenant/:tenant_id", tenant_status)
.get("/v1/tenant/:tenant_id/synthetic_size", tenant_size_handler)
.put("/v1/tenant/config", update_tenant_config_handler)
.get("/v1/tenant/:tenant_id/config", get_tenant_config_handler)
.get("/v1/tenant/:tenant_id/timeline", timeline_list_handler)
.post("/v1/tenant/:tenant_id/timeline", timeline_create_handler)
.post("/v1/tenant/:tenant_id/attach", tenant_attach_handler)
.post("/v1/tenant/:tenant_id/detach", tenant_detach_handler)
.post("/v1/tenant/:tenant_id/load", tenant_load_handler)
.post("/v1/tenant/:tenant_id/ignore", tenant_ignore_handler)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_detail_handler,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/get_lsn_by_timestamp",
|r| RequestSpan(get_lsn_by_timestamp_handler).handle(r),
get_lsn_by_timestamp_handler,
)
.put(
"/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc",
timeline_gc_handler,
)
.put("/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc", |r| {
RequestSpan(timeline_gc_handler).handle(r)
})
.put(
"/v1/tenant/:tenant_id/timeline/:timeline_id/compact",
testing_api!("run timeline compaction", timeline_compact_handler),
@@ -1152,26 +1138,28 @@ pub fn make_router(
)
.post(
"/v1/tenant/:tenant_id/timeline/:timeline_id/download_remote_layers",
|r| RequestSpan(timeline_download_remote_layers_handler_post).handle(r),
timeline_download_remote_layers_handler_post,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/download_remote_layers",
|r| RequestSpan(timeline_download_remote_layers_handler_get).handle(r),
timeline_download_remote_layers_handler_get,
)
.delete(
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_delete_handler,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer",
layer_map_info_handler,
)
.delete("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
RequestSpan(timeline_delete_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline/:timeline_id/layer", |r| {
RequestSpan(layer_map_info_handler).handle(r)
})
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer/:layer_file_name",
|r| RequestSpan(layer_download_handler).handle(r),
layer_download_handler,
)
.delete(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer/:layer_file_name",
|r| RequestSpan(evict_timeline_layer_handler).handle(r),
evict_timeline_layer_handler,
)
.get("/v1/panic", |r| RequestSpan(always_panic_handler).handle(r))
.get("/v1/panic", always_panic_handler)
.any(handler_404))
}

View File

@@ -1,12 +1,11 @@
use crate::repository::{key_range_size, singleton_range, Key};
use postgres_ffi::BLCKSZ;
use std::ops::Range;
use tracing::debug;
///
/// Represents a set of Keys, in a compact form.
///
#[derive(Clone, Debug, Default)]
#[derive(Clone, Debug)]
pub struct KeySpace {
/// Contiguous ranges of keys that belong to the key space. In key order,
/// and with no overlap.
@@ -62,60 +61,6 @@ impl KeySpace {
KeyPartitioning { parts }
}
/// Add range to keyspace.
///
/// Unlike KeySpaceAccum, it accepts key ranges in any order and overlapping ranges.
pub fn add_range(&mut self, range: Range<Key>) {
let start = range.start;
let mut end = range.end;
let mut prev_index = match self.ranges.binary_search_by_key(&end, |r| r.start) {
Ok(index) => index,
Err(0) => {
self.ranges.insert(0, range);
return;
}
Err(index) => index - 1,
};
loop {
let mut prev = &mut self.ranges[prev_index];
if prev.end >= start {
// two ranges overlap
if prev.start <= start {
// combine with prev range
if prev.end < end {
prev.end = end;
debug!("Extend wanted image {}..{}", prev.start, end);
}
return;
} else {
if prev.end > end {
end = prev.end;
}
self.ranges.remove(prev_index);
}
} else {
break;
}
if prev_index == 0 {
break;
}
prev_index -= 1;
}
debug!("Wanted image {}..{}", start, end);
self.ranges.insert(prev_index, start..end);
}
///
/// Check if key space contains overlapping range
///
pub fn overlaps(&self, range: &Range<Key>) -> bool {
match self.ranges.binary_search_by_key(&range.end, |r| r.start) {
Ok(_) => false,
Err(0) => false,
Err(index) => self.ranges[index - 1].end > range.start,
}
}
}
///

View File

@@ -9,18 +9,22 @@ use once_cell::sync::Lazy;
use pageserver_api::models::state;
use utils::id::{TenantId, TimelineId};
/// Prometheus histogram buckets (in seconds) for operations in the critical
/// path. In other words, operations that directly affect that latency of user
/// queries.
///
/// The buckets capture the majority of latencies in the microsecond and
/// millisecond range but also extend far enough up to distinguish "bad" from
/// "really bad".
const CRITICAL_OP_BUCKETS: &[f64] = &[
0.000_001, 0.000_010, 0.000_100, // 1 us, 10 us, 100 us
0.001_000, 0.010_000, 0.100_000, // 1 ms, 10 ms, 100 ms
1.0, 10.0, 100.0, // 1 s, 10 s, 100 s
];
/// Prometheus histogram buckets (in seconds) that capture the majority of
/// latencies in the microsecond range but also extend far enough up to distinguish
/// "bad" from "really bad".
fn get_buckets_for_critical_operations() -> Vec<f64> {
let buckets_per_digit = 5;
let min_exponent = -6;
let max_exponent = 2;
let mut buckets = vec![];
// Compute 10^(exp / buckets_per_digit) instead of 10^(1/buckets_per_digit)^exp
// because it's more numerically stable and doesn't result in numbers like 9.999999
for exp in (min_exponent * buckets_per_digit)..=(max_exponent * buckets_per_digit) {
buckets.push(10_f64.powf(exp as f64 / buckets_per_digit as f64))
}
buckets
}
// Metrics collected on operations on the storage repository.
const STORAGE_TIME_OPERATIONS: &[&str] = &[
@@ -51,15 +55,12 @@ pub static STORAGE_TIME_COUNT_PER_TIMELINE: Lazy<IntCounterVec> = Lazy::new(|| {
.expect("failed to define a metric")
});
// Buckets for background operations like compaction, GC, size calculation
const STORAGE_OP_BUCKETS: &[f64] = &[0.010, 0.100, 1.0, 10.0, 100.0, 1000.0];
pub static STORAGE_TIME_GLOBAL: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"pageserver_storage_operations_seconds_global",
"Time spent on storage operations",
&["operation"],
STORAGE_OP_BUCKETS.into(),
get_buckets_for_critical_operations(),
)
.expect("failed to define a metric")
});
@@ -70,7 +71,7 @@ static RECONSTRUCT_TIME: Lazy<HistogramVec> = Lazy::new(|| {
"pageserver_getpage_reconstruct_seconds",
"Time spent in reconstruct_value",
&["tenant_id", "timeline_id"],
CRITICAL_OP_BUCKETS.into(),
get_buckets_for_critical_operations(),
)
.expect("failed to define a metric")
});
@@ -89,7 +90,7 @@ static WAIT_LSN_TIME: Lazy<HistogramVec> = Lazy::new(|| {
"pageserver_wait_lsn_seconds",
"Time spent waiting for WAL to arrive",
&["tenant_id", "timeline_id"],
CRITICAL_OP_BUCKETS.into(),
get_buckets_for_critical_operations(),
)
.expect("failed to define a metric")
});
@@ -122,22 +123,6 @@ static REMOTE_PHYSICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
.expect("failed to define a metric")
});
pub static REMOTE_ONDEMAND_DOWNLOADED_LAYERS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_remote_ondemand_downloaded_layers_total",
"Total on-demand downloaded layers"
)
.unwrap()
});
pub static REMOTE_ONDEMAND_DOWNLOADED_BYTES: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_remote_ondemand_downloaded_bytes_total",
"Total bytes of layers on-demand downloaded",
)
.unwrap()
});
static CURRENT_LOGICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_current_logical_size",
@@ -194,101 +179,15 @@ static PERSISTENT_BYTES_WRITTEN: Lazy<IntCounterVec> = Lazy::new(|| {
.expect("failed to define a metric")
});
static EVICTIONS: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"pageserver_evictions",
"Number of layers evicted from the pageserver",
&["tenant_id", "timeline_id"]
)
.expect("failed to define a metric")
});
static EVICTIONS_WITH_LOW_RESIDENCE_DURATION: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"pageserver_evictions_with_low_residence_duration",
"If a layer is evicted that was resident for less than `low_threshold`, it is counted to this counter. \
Residence duration is determined using the `residence_duration_data_source`.",
&["tenant_id", "timeline_id", "residence_duration_data_source", "low_threshold_secs"]
)
.expect("failed to define a metric")
});
/// Each [`Timeline`]'s [`EVICTIONS_WITH_LOW_RESIDENCE_DURATION`] metric.
#[derive(Debug)]
pub struct EvictionsWithLowResidenceDuration {
data_source: &'static str,
threshold: Duration,
counter: Option<IntCounter>,
}
pub struct EvictionsWithLowResidenceDurationBuilder {
data_source: &'static str,
threshold: Duration,
}
impl EvictionsWithLowResidenceDurationBuilder {
pub fn new(data_source: &'static str, threshold: Duration) -> Self {
Self {
data_source,
threshold,
}
}
fn build(&self, tenant_id: &str, timeline_id: &str) -> EvictionsWithLowResidenceDuration {
let counter = EVICTIONS_WITH_LOW_RESIDENCE_DURATION
.get_metric_with_label_values(&[
tenant_id,
timeline_id,
self.data_source,
&EvictionsWithLowResidenceDuration::threshold_label_value(self.threshold),
])
.unwrap();
EvictionsWithLowResidenceDuration {
data_source: self.data_source,
threshold: self.threshold,
counter: Some(counter),
}
}
}
impl EvictionsWithLowResidenceDuration {
fn threshold_label_value(threshold: Duration) -> String {
format!("{}", threshold.as_secs())
}
pub fn observe(&self, observed_value: Duration) {
if self.threshold < observed_value {
self.counter
.as_ref()
.expect("nobody calls this function after `remove_from_vec`")
.inc();
}
}
// This could be a `Drop` impl, but, we need the `tenant_id` and `timeline_id`.
fn remove(&mut self, tenant_id: &str, timeline_id: &str) {
let Some(_counter) = self.counter.take() else {
return;
};
EVICTIONS_WITH_LOW_RESIDENCE_DURATION
.remove_label_values(&[
tenant_id,
timeline_id,
self.data_source,
&Self::threshold_label_value(self.threshold),
])
.expect("we own the metric, no-one else should remove it");
}
}
// Metrics collected on disk IO operations
//
// Roughly logarithmic scale.
const STORAGE_IO_TIME_BUCKETS: &[f64] = &[
0.000030, // 30 usec
0.001000, // 1000 usec
0.030, // 30 ms
1.000, // 1000 ms
0.000001, // 1 usec
0.00001, // 10 usec
0.0001, // 100 usec
0.001, // 1 msec
0.01, // 10 msec
0.1, // 100 msec
1.0, // 1 sec
];
const STORAGE_IO_TIME_OPERATIONS: &[&str] = &[
@@ -323,12 +222,20 @@ const SMGR_QUERY_TIME_OPERATIONS: &[&str] = &[
"get_db_size",
];
const SMGR_QUERY_TIME_BUCKETS: &[f64] = &[
0.00001, // 1/100000 s
0.0001, 0.00015, 0.0002, 0.00025, 0.0003, 0.00035, 0.0005, 0.00075, // 1/10000 s
0.001, 0.0025, 0.005, 0.0075, // 1/1000 s
0.01, 0.0125, 0.015, 0.025, 0.05, // 1/100 s
0.1, // 1/10 s
];
pub static SMGR_QUERY_TIME: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"pageserver_smgr_query_seconds",
"Time spent on smgr query handling",
&["smgr_query_type", "tenant_id", "timeline_id"],
CRITICAL_OP_BUCKETS.into(),
SMGR_QUERY_TIME_BUCKETS.into()
)
.expect("failed to define a metric")
});
@@ -597,16 +504,10 @@ pub struct TimelineMetrics {
pub current_logical_size_gauge: UIntGauge,
pub num_persistent_files_created: IntCounter,
pub persistent_bytes_written: IntCounter,
pub evictions: IntCounter,
pub evictions_with_low_residence_duration: EvictionsWithLowResidenceDuration,
}
impl TimelineMetrics {
pub fn new(
tenant_id: &TenantId,
timeline_id: &TimelineId,
evictions_with_low_residence_duration_builder: EvictionsWithLowResidenceDurationBuilder,
) -> Self {
pub fn new(tenant_id: &TenantId, timeline_id: &TimelineId) -> Self {
let tenant_id = tenant_id.to_string();
let timeline_id = timeline_id.to_string();
let reconstruct_time_histo = RECONSTRUCT_TIME
@@ -643,11 +544,6 @@ impl TimelineMetrics {
let persistent_bytes_written = PERSISTENT_BYTES_WRITTEN
.get_metric_with_label_values(&[&tenant_id, &timeline_id])
.unwrap();
let evictions = EVICTIONS
.get_metric_with_label_values(&[&tenant_id, &timeline_id])
.unwrap();
let evictions_with_low_residence_duration =
evictions_with_low_residence_duration_builder.build(&tenant_id, &timeline_id);
TimelineMetrics {
tenant_id,
@@ -667,8 +563,6 @@ impl TimelineMetrics {
current_logical_size_gauge,
num_persistent_files_created,
persistent_bytes_written,
evictions,
evictions_with_low_residence_duration,
}
}
}
@@ -685,9 +579,7 @@ impl Drop for TimelineMetrics {
let _ = CURRENT_LOGICAL_SIZE.remove_label_values(&[tenant_id, timeline_id]);
let _ = NUM_PERSISTENT_FILES_CREATED.remove_label_values(&[tenant_id, timeline_id]);
let _ = PERSISTENT_BYTES_WRITTEN.remove_label_values(&[tenant_id, timeline_id]);
let _ = EVICTIONS.remove_label_values(&[tenant_id, timeline_id]);
self.evictions_with_low_residence_duration
.remove(tenant_id, timeline_id);
for op in STORAGE_TIME_OPERATIONS {
let _ =
STORAGE_TIME_SUM_PER_TIMELINE.remove_label_values(&[op, tenant_id, timeline_id]);
@@ -722,7 +614,7 @@ use std::collections::HashMap;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use std::time::Instant;
pub struct RemoteTimelineClientMetrics {
tenant_id: String,

View File

@@ -12,7 +12,7 @@
use anyhow::Context;
use bytes::Buf;
use bytes::Bytes;
use futures::Stream;
use futures::{Stream, StreamExt};
use pageserver_api::models::TenantState;
use pageserver_api::models::{
PagestreamBeMessage, PagestreamDbSizeRequest, PagestreamDbSizeResponse,
@@ -20,9 +20,7 @@ use pageserver_api::models::{
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
PagestreamNblocksRequest, PagestreamNblocksResponse,
};
use postgres_backend::PostgresBackendTCP;
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
use pq_proto::framed::ConnectionError;
use pq_proto::ConnectionError;
use pq_proto::FeStartupPacket;
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
use std::io;
@@ -31,13 +29,14 @@ use std::str;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio_util::io::StreamReader;
use tracing::*;
use utils::id::ConnectionId;
use utils::{
auth::{Claims, JwtAuth, Scope},
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
postgres_backend_async::{self, is_expected_io_error, PostgresBackend, QueryError},
simple_rcu::RcuReadGuard,
};
@@ -56,7 +55,7 @@ use crate::trace::Tracer;
use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
use postgres_ffi::BLCKSZ;
fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<Bytes>> + '_ {
fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Bytes>> + '_ {
async_stream::try_stream! {
loop {
let msg = tokio::select! {
@@ -65,11 +64,11 @@ fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<
_ = task_mgr::shutdown_watcher() => {
// We were requested to shut down.
let msg = format!("pageserver is shutting down");
let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None));
let _ = pgb.write_message(&BeMessage::ErrorResponse(&msg, None));
Err(QueryError::Other(anyhow::anyhow!(msg)))
}
msg = pgb.read_message() => { msg.map_err(QueryError::from)}
msg = pgb.read_message() => { msg }
};
match msg {
@@ -80,16 +79,14 @@ fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<
FeMessage::Sync => continue,
FeMessage::Terminate => {
let msg = "client terminated connection with Terminate message during COPY";
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
break;
}
m => {
let msg = format!("unexpected message {m:?}");
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?;
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))?;
Err(io::Error::new(io::ErrorKind::Other, msg))?;
break;
}
@@ -99,66 +96,22 @@ fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<
}
Ok(None) => {
let msg = "client closed connection during COPY";
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
pgb.flush().await?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
}
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
Err(io_error)?;
}
Err(other) => {
Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
Err(io::Error::new(io::ErrorKind::Other, other))?;
}
};
}
}
}
/// Read the end of a tar archive.
///
/// A tar archive normally ends with two consecutive blocks of zeros, 512 bytes each.
/// `tokio_tar` already read the first such block. Read the second all-zeros block,
/// and check that there is no more data after the EOF marker.
///
/// XXX: Currently, any trailing data after the EOF marker prints a warning.
/// Perhaps it should be a hard error?
async fn read_tar_eof(mut reader: (impl tokio::io::AsyncRead + Unpin)) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt;
let mut buf = [0u8; 512];
// Read the all-zeros block, and verify it
let mut total_bytes = 0;
while total_bytes < 512 {
let nbytes = reader.read(&mut buf[total_bytes..]).await?;
total_bytes += nbytes;
if nbytes == 0 {
break;
}
}
if total_bytes < 512 {
anyhow::bail!("incomplete or invalid tar EOF marker");
}
if !buf.iter().all(|&x| x == 0) {
anyhow::bail!("invalid tar EOF marker");
}
// Drain any data after the EOF marker
let mut trailing_bytes = 0;
loop {
let nbytes = reader.read(&mut buf).await?;
trailing_bytes += nbytes;
if nbytes == 0 {
break;
}
}
if trailing_bytes > 0 {
warn!("ignored {trailing_bytes} unexpected bytes after the tar archive");
}
Ok(())
}
///////////////////////////////////////////////////////////////////////////////
///
@@ -259,7 +212,7 @@ async fn page_service_conn_main(
// we've been requested to shut down
Ok(())
}
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
if is_expected_io_error(&io_error) {
info!("Postgres client disconnected ({io_error})");
Ok(())
@@ -333,7 +286,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_pagerequests(
&self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
tenant_id: TenantId,
timeline_id: TimelineId,
ctx: RequestContext,
@@ -358,7 +311,7 @@ impl PageServerHandler {
let timeline = tenant.get_timeline(timeline_id, true)?;
// switch client to COPYBOTH
pgb.write_message_noflush(&BeMessage::CopyBothResponse)?;
pgb.write_message(&BeMessage::CopyBothResponse)?;
pgb.flush().await?;
let metrics = PageRequestMetrics::new(&tenant_id, &timeline_id);
@@ -427,7 +380,7 @@ impl PageServerHandler {
})
});
pgb.write_message_noflush(&BeMessage::CopyData(&response.serialize()))?;
pgb.write_message(&BeMessage::CopyData(&response.serialize()))?;
pgb.flush().await?;
}
Ok(())
@@ -437,7 +390,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_import_basebackup(
&self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
tenant_id: TenantId,
timeline_id: TimelineId,
base_lsn: Lsn,
@@ -463,17 +416,22 @@ impl PageServerHandler {
// Import basebackup provided via CopyData
info!("importing basebackup");
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
pgb.write_message(&BeMessage::CopyInResponse)?;
pgb.flush().await?;
let copyin_reader = StreamReader::new(copyin_stream(pgb));
tokio::pin!(copyin_reader);
let mut copyin_stream = Box::pin(copyin_stream(pgb));
timeline
.import_basebackup_from_tar(&mut copyin_reader, base_lsn, &ctx)
.import_basebackup_from_tar(&mut copyin_stream, base_lsn, &ctx)
.await?;
// Read the end of the tar archive.
read_tar_eof(copyin_reader).await?;
// Drain the rest of the Copy data
let mut bytes_after_tar = 0;
while let Some(bytes) = copyin_stream.next().await {
bytes_after_tar += bytes?.len();
}
if bytes_after_tar > 0 {
warn!("ignored {bytes_after_tar} unexpected bytes after the tar archive");
}
// TODO check checksum
// Meanwhile you can verify client-side by taking fullbackup
@@ -488,7 +446,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_import_wal(
&self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
tenant_id: TenantId,
timeline_id: TimelineId,
start_lsn: Lsn,
@@ -510,15 +468,21 @@ impl PageServerHandler {
// Import wal provided via CopyData
info!("importing wal");
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
pgb.write_message(&BeMessage::CopyInResponse)?;
pgb.flush().await?;
let copyin_reader = StreamReader::new(copyin_stream(pgb));
tokio::pin!(copyin_reader);
import_wal_from_tar(&timeline, &mut copyin_reader, start_lsn, end_lsn, &ctx).await?;
let mut copyin_stream = Box::pin(copyin_stream(pgb));
let mut reader = tokio_util::io::StreamReader::new(&mut copyin_stream);
import_wal_from_tar(&timeline, &mut reader, start_lsn, end_lsn, &ctx).await?;
info!("wal import complete");
// Read the end of the tar archive.
read_tar_eof(copyin_reader).await?;
// Drain the rest of the Copy data
let mut bytes_after_tar = 0;
while let Some(bytes) = copyin_stream.next().await {
bytes_after_tar += bytes?.len();
}
if bytes_after_tar > 0 {
warn!("ignored {bytes_after_tar} unexpected bytes after the tar archive");
}
// TODO Does it make sense to overshoot?
if timeline.get_last_record_lsn() < end_lsn {
@@ -693,7 +657,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_basebackup_request(
&mut self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
tenant_id: TenantId,
timeline_id: TimelineId,
lsn: Option<Lsn>,
@@ -714,7 +678,7 @@ impl PageServerHandler {
}
// switch client to COPYOUT
pgb.write_message_noflush(&BeMessage::CopyOutResponse)?;
pgb.write_message(&BeMessage::CopyOutResponse)?;
pgb.flush().await?;
// Send a tarball of the latest layer on the timeline
@@ -731,7 +695,7 @@ impl PageServerHandler {
.await?;
}
pgb.write_message_noflush(&BeMessage::CopyDone)?;
pgb.write_message(&BeMessage::CopyDone)?;
pgb.flush().await?;
info!("basebackup complete");
@@ -757,10 +721,10 @@ impl PageServerHandler {
}
#[async_trait::async_trait]
impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
impl postgres_backend_async::Handler for PageServerHandler {
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackendTCP,
_pgb: &mut PostgresBackend,
jwt_response: &[u8],
) -> Result<(), QueryError> {
// this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
@@ -788,7 +752,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
fn startup(
&mut self,
_pgb: &mut PostgresBackendTCP,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
@@ -796,7 +760,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
async fn process_query(
&mut self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError> {
let ctx = self.connection_ctx.attached_child();
@@ -848,7 +812,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
// Check that the timeline exists
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, None, false, ctx)
.await?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
// return pair of prev_lsn and last_lsn
else if query_string.starts_with("get_last_record_rlsn ") {
@@ -871,15 +835,15 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
let end_of_timeline = timeline.get_last_record_rlsn();
pgb.write_message_noflush(&BeMessage::RowDescription(&[
pgb.write_message(&BeMessage::RowDescription(&[
RowDescriptor::text_col(b"prev_lsn"),
RowDescriptor::text_col(b"last_lsn"),
]))?
.write_message_noflush(&BeMessage::DataRow(&[
.write_message(&BeMessage::DataRow(&[
Some(end_of_timeline.prev.to_string().as_bytes()),
Some(end_of_timeline.last.to_string().as_bytes()),
]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
// same as basebackup, but result includes relational data as well
else if query_string.starts_with("fullbackup ") {
@@ -920,7 +884,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
// Check that the timeline exists
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, prev_lsn, true, ctx)
.await?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("import basebackup ") {
// Import the `base` section (everything but the wal) of a basebackup.
// Assumes the tenant already exists on this pageserver.
@@ -965,10 +929,10 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
)
.await
{
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => {
error!("error importing base backup between {base_lsn} and {end_lsn}: {e:?}");
pgb.write_message_noflush(&BeMessage::ErrorResponse(
pgb.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?
@@ -1001,10 +965,10 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
.handle_import_wal(pgb, tenant_id, timeline_id, start_lsn, end_lsn, ctx)
.await
{
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => {
error!("error importing WAL between {start_lsn} and {end_lsn}: {e:?}");
pgb.write_message_noflush(&BeMessage::ErrorResponse(
pgb.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?
@@ -1013,7 +977,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
} else if query_string.to_ascii_lowercase().starts_with("set ") {
// important because psycopg2 executes "SET datestyle TO 'ISO'"
// on connect
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("show ") {
// show <tenant_id>
let (_, params_raw) = query_string.split_at("show ".len());
@@ -1029,7 +993,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
self.check_permission(Some(tenant_id))?;
let tenant = get_active_tenant_with_timeout(tenant_id, &ctx).await?;
pgb.write_message_noflush(&BeMessage::RowDescription(&[
pgb.write_message(&BeMessage::RowDescription(&[
RowDescriptor::int8_col(b"checkpoint_distance"),
RowDescriptor::int8_col(b"checkpoint_timeout"),
RowDescriptor::int8_col(b"compaction_target_size"),
@@ -1040,7 +1004,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
RowDescriptor::int8_col(b"image_creation_threshold"),
RowDescriptor::int8_col(b"pitr_interval"),
]))?
.write_message_noflush(&BeMessage::DataRow(&[
.write_message(&BeMessage::DataRow(&[
Some(tenant.get_checkpoint_distance().to_string().as_bytes()),
Some(
tenant
@@ -1063,7 +1027,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
Some(tenant.get_image_creation_threshold().to_string().as_bytes()),
Some(tenant.get_pitr_interval().as_secs().to_string().as_bytes()),
]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else {
return Err(QueryError::Other(anyhow::anyhow!(
"unknown command {query_string}"
@@ -1091,7 +1055,7 @@ impl From<GetActiveTenantError> for QueryError {
fn from(e: GetActiveTenantError) -> Self {
match e {
GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected(
ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
ConnectionError::Socket(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
),
GetActiveTenantError::Other(e) => QueryError::Other(e),
}
@@ -1107,10 +1071,7 @@ async fn get_active_tenant_with_timeout(
tenant_id: TenantId,
_ctx: &RequestContext, /* require get a context to support cancellation in the future */
) -> Result<Arc<Tenant>, GetActiveTenantError> {
let tenant = match mgr::get_tenant(tenant_id, false).await {
Ok(tenant) => tenant,
Err(e) => return Err(GetActiveTenantError::Other(e.into())),
};
let tenant = mgr::get_tenant(tenant_id, false).await?;
let wait_time = Duration::from_secs(30);
match tokio::time::timeout(wait_time, tenant.wait_to_become_active()).await {
Ok(Ok(())) => Ok(tenant),

View File

@@ -12,7 +12,9 @@
//!
use anyhow::{bail, Context};
use bytes::Bytes;
use futures::FutureExt;
use futures::Stream;
use pageserver_api::models::TimelineState;
use remote_storage::DownloadError;
use remote_storage::GenericRemoteStorage;
@@ -237,13 +239,14 @@ impl UninitializedTimeline<'_> {
/// Prepares timeline data by loading it from the basebackup archive.
pub async fn import_basebackup_from_tar(
self,
copyin_read: &mut (impl tokio::io::AsyncRead + Send + Sync + Unpin),
copyin_stream: &mut (impl Stream<Item = io::Result<Bytes>> + Sync + Send + Unpin),
base_lsn: Lsn,
ctx: &RequestContext,
) -> anyhow::Result<Arc<Timeline>> {
let raw_timeline = self.raw_timeline()?;
import_datadir::import_basebackup_from_tar(raw_timeline, copyin_read, base_lsn, ctx)
let mut reader = tokio_util::io::StreamReader::new(copyin_stream);
import_datadir::import_basebackup_from_tar(raw_timeline, &mut reader, base_lsn, ctx)
.await
.context("Failed to import basebackup")?;
@@ -478,7 +481,7 @@ impl Tenant {
let dummy_timeline = self.create_timeline_data(
timeline_id,
up_to_date_metadata,
up_to_date_metadata.clone(),
ancestor.clone(),
remote_client,
)?;
@@ -503,7 +506,7 @@ impl Tenant {
let broken_timeline = self
.create_timeline_data(
timeline_id,
up_to_date_metadata,
up_to_date_metadata.clone(),
ancestor.clone(),
None,
)
@@ -1142,7 +1145,7 @@ impl Tenant {
);
self.prepare_timeline(
new_timeline_id,
&new_metadata,
new_metadata,
timeline_uninit_mark,
true,
None,
@@ -1240,8 +1243,11 @@ impl Tenant {
"Cannot run GC iteration on inactive tenant"
);
self.gc_iteration_internal(target_timeline_id, horizon, pitr, ctx)
.await
let gc_result = self
.gc_iteration_internal(target_timeline_id, horizon, pitr, ctx)
.await;
gc_result
}
/// Perform one compaction iteration.
@@ -1700,7 +1706,7 @@ impl Tenant {
fn create_timeline_data(
&self,
new_timeline_id: TimelineId,
new_metadata: &TimelineMetadata,
new_metadata: TimelineMetadata,
ancestor: Option<Arc<Timeline>>,
remote_client: Option<RemoteTimelineClient>,
) -> anyhow::Result<Arc<Timeline>> {
@@ -2160,25 +2166,13 @@ impl Tenant {
let new_timeline = self
.prepare_timeline(
dst_id,
&metadata,
metadata,
timeline_uninit_mark,
false,
Some(Arc::clone(src_timeline)),
)?
.initialize_with_lock(&mut timelines, true, true)?;
drop(timelines);
// Root timeline gets its layers during creation and uploads them along with the metadata.
// A branch timeline though, when created, can get no writes for some time, hence won't get any layers created.
// We still need to upload its metadata eagerly: if other nodes `attach` the tenant and miss this timeline, their GC
// could get incorrect information and remove more layers, than needed.
// See also https://github.com/neondatabase/neon/issues/3865
if let Some(remote_client) = new_timeline.remote_client.as_ref() {
remote_client
.schedule_index_upload_for_metadata_update(&metadata)
.context("branch initial metadata upload")?;
}
info!("branched timeline {dst_id} from {src_id} at {start_lsn}");
Ok(new_timeline)
@@ -2241,7 +2235,7 @@ impl Tenant {
pg_version,
);
let raw_timeline =
self.prepare_timeline(timeline_id, &new_metadata, timeline_uninit_mark, true, None)?;
self.prepare_timeline(timeline_id, new_metadata, timeline_uninit_mark, true, None)?;
let tenant_id = raw_timeline.owning_tenant.tenant_id;
let unfinished_timeline = raw_timeline.raw_timeline()?;
@@ -2295,7 +2289,7 @@ impl Tenant {
fn prepare_timeline(
&self,
new_timeline_id: TimelineId,
new_metadata: &TimelineMetadata,
new_metadata: TimelineMetadata,
uninit_mark: TimelineUninitMark,
init_layers: bool,
ancestor: Option<Arc<Timeline>>,
@@ -2309,7 +2303,7 @@ impl Tenant {
tenant_id,
new_timeline_id,
);
remote_client.init_upload_queue_for_empty_remote(new_metadata)?;
remote_client.init_upload_queue_for_empty_remote(&new_metadata)?;
Some(remote_client)
} else {
None
@@ -2348,12 +2342,17 @@ impl Tenant {
&self,
timeline_path: &Path,
new_timeline_id: TimelineId,
new_metadata: &TimelineMetadata,
new_metadata: TimelineMetadata,
ancestor: Option<Arc<Timeline>>,
remote_client: Option<RemoteTimelineClient>,
) -> anyhow::Result<Arc<Timeline>> {
let timeline_data = self
.create_timeline_data(new_timeline_id, new_metadata, ancestor, remote_client)
.create_timeline_data(
new_timeline_id,
new_metadata.clone(),
ancestor,
remote_client,
)
.context("Failed to create timeline data structure")?;
crashsafe::create_dir_all(timeline_path).context("Failed to create timeline directory")?;
@@ -2365,7 +2364,7 @@ impl Tenant {
self.conf,
new_timeline_id,
self.tenant_id,
new_metadata,
&new_metadata,
true,
)
.context("Failed to create timeline metadata")?;
@@ -3177,44 +3176,6 @@ mod tests {
}
*/
#[tokio::test]
async fn test_get_branchpoints_from_an_inactive_timeline() -> anyhow::Result<()> {
let (tenant, ctx) =
TenantHarness::create("test_get_branchpoints_from_an_inactive_timeline")?
.load()
.await;
let tline = tenant
.create_empty_timeline(TIMELINE_ID, Lsn(0), DEFAULT_PG_VERSION, &ctx)?
.initialize(&ctx)?;
make_some_layers(tline.as_ref(), Lsn(0x20)).await?;
tenant
.branch_timeline(&tline, NEW_TIMELINE_ID, Some(Lsn(0x40)), &ctx)
.await?;
let newtline = tenant
.get_timeline(NEW_TIMELINE_ID, true)
.expect("Should have a local timeline");
make_some_layers(newtline.as_ref(), Lsn(0x60)).await?;
tline.set_state(TimelineState::Broken);
tenant
.gc_iteration(Some(TIMELINE_ID), 0x10, Duration::ZERO, &ctx)
.await?;
assert_eq!(
newtline.get(*TEST_KEY, Lsn(0x50), &ctx).await?,
TEST_IMG(&format!("foo at {}", Lsn(0x40)))
);
let branchpoints = &tline.gc_info.read().unwrap().retain_lsns;
assert_eq!(branchpoints.len(), 1);
assert_eq!(branchpoints[0], Lsn(0x40));
Ok(())
}
#[tokio::test]
async fn test_retain_data_in_parent_which_is_needed_for_child() -> anyhow::Result<()> {
let (tenant, ctx) =

View File

@@ -51,6 +51,9 @@ where
///
/// A "cursor" for efficiently reading multiple pages from a BlockReader
///
/// A cursor caches the last accessed page, allowing for faster access if the
/// same block is accessed repeatedly.
///
/// You can access the last page with `*cursor`. 'read_blk' returns 'self', so
/// that in many cases you can use a BlockCursor as a drop-in replacement for
/// the underlying BlockReader. For example:
@@ -70,6 +73,8 @@ where
R: BlockReader,
{
reader: R,
/// last accessed page
cache: Option<(u32, R::BlockLease)>,
}
impl<R> BlockCursor<R>
@@ -77,13 +82,40 @@ where
R: BlockReader,
{
pub fn new(reader: R) -> Self {
BlockCursor { reader }
BlockCursor {
reader,
cache: None,
}
}
pub fn read_blk(&mut self, blknum: u32) -> Result<R::BlockLease, std::io::Error> {
self.reader.read_blk(blknum)
pub fn read_blk(&mut self, blknum: u32) -> Result<&Self, std::io::Error> {
// Fast return if this is the same block as before
if let Some((cached_blk, _buf)) = &self.cache {
if *cached_blk == blknum {
return Ok(self);
}
}
// Read the block from the underlying reader, and cache it
self.cache = None;
let buf = self.reader.read_blk(blknum)?;
self.cache = Some((blknum, buf));
Ok(self)
}
}
impl<R> Deref for BlockCursor<R>
where
R: BlockReader,
{
type Target = [u8; PAGE_SZ];
fn deref(&self) -> &<Self as Deref>::Target {
&self.cache.as_ref().unwrap().1
}
}
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
/// An adapter for reading a (virtual) file using the page cache.

View File

@@ -222,6 +222,48 @@ impl TenantConfOpt {
eviction_policy: self.eviction_policy.unwrap_or(global_conf.eviction_policy),
}
}
pub fn update(&mut self, other: &TenantConfOpt) {
if let Some(checkpoint_distance) = other.checkpoint_distance {
self.checkpoint_distance = Some(checkpoint_distance);
}
if let Some(checkpoint_timeout) = other.checkpoint_timeout {
self.checkpoint_timeout = Some(checkpoint_timeout);
}
if let Some(compaction_target_size) = other.compaction_target_size {
self.compaction_target_size = Some(compaction_target_size);
}
if let Some(compaction_period) = other.compaction_period {
self.compaction_period = Some(compaction_period);
}
if let Some(compaction_threshold) = other.compaction_threshold {
self.compaction_threshold = Some(compaction_threshold);
}
if let Some(gc_horizon) = other.gc_horizon {
self.gc_horizon = Some(gc_horizon);
}
if let Some(gc_period) = other.gc_period {
self.gc_period = Some(gc_period);
}
if let Some(image_creation_threshold) = other.image_creation_threshold {
self.image_creation_threshold = Some(image_creation_threshold);
}
if let Some(pitr_interval) = other.pitr_interval {
self.pitr_interval = Some(pitr_interval);
}
if let Some(walreceiver_connect_timeout) = other.walreceiver_connect_timeout {
self.walreceiver_connect_timeout = Some(walreceiver_connect_timeout);
}
if let Some(lagging_wal_timeout) = other.lagging_wal_timeout {
self.lagging_wal_timeout = Some(lagging_wal_timeout);
}
if let Some(max_lsn_wal_lag) = other.max_lsn_wal_lag {
self.max_lsn_wal_lag = Some(max_lsn_wal_lag);
}
if let Some(trace_read_requests) = other.trace_read_requests {
self.trace_read_requests = Some(trace_read_requests);
}
}
}
impl Default for TenantConf {

View File

@@ -2,7 +2,9 @@
//! used to keep in-memory layers spilled on disk.
use crate::config::PageServerConf;
use crate::page_cache::{self, ReadBufResult, WriteBufResult, PAGE_SZ};
use crate::page_cache;
use crate::page_cache::PAGE_SZ;
use crate::page_cache::{ReadBufResult, WriteBufResult};
use crate::tenant::blob_io::BlobWriter;
use crate::tenant::block_io::BlockReader;
use crate::virtual_file::VirtualFile;
@@ -425,6 +427,7 @@ mod tests {
let actual = cursor.read_blob(pos)?;
assert_eq!(actual, expected);
}
drop(cursor);
// Test a large blob that spans multiple pages
let mut large_data = Vec::new();

View File

@@ -154,7 +154,11 @@ where
expected: &Arc<L>,
new: Arc<L>,
) -> anyhow::Result<Replacement<Arc<L>>> {
fail::fail_point!("layermap-replace-notfound", |_| Ok(Replacement::NotFound));
fail::fail_point!("layermap-replace-notfound", |_| Ok(
// this is not what happens if an L0 layer was not found a anyhow error but perhaps
// that should be changed. this is good enough to show a replacement failure.
Replacement::NotFound
));
self.layer_map.replace_historic_noflush(expected, new)
}
@@ -336,15 +340,12 @@ where
let l0_index = if expected_l0 {
// find the index in case replace worked, we need to replace that as well
let pos = self
.l0_delta_layers
.iter()
.position(|slot| Self::compare_arced_layers(slot, expected));
if pos.is_none() {
return Ok(Replacement::NotFound);
}
pos
Some(
self.l0_delta_layers
.iter()
.position(|slot| Self::compare_arced_layers(slot, expected))
.ok_or_else(|| anyhow::anyhow!("existing l0 delta layer was not found"))?,
)
} else {
None
};
@@ -803,26 +804,6 @@ mod tests {
)
}
#[test]
fn replacing_missing_l0_is_notfound() {
// original impl had an oversight, and L0 was an anyhow::Error. anyhow::Error should
// however only happen for precondition failures.
let layer = "000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000053423C21-0000000053424D69";
let layer = LayerFileName::from_str(layer).unwrap();
let layer = LayerDescriptor::from(layer);
// same skeletan construction; see scenario below
let not_found: Arc<dyn Layer> = Arc::new(layer.clone());
let new_version: Arc<dyn Layer> = Arc::new(layer);
let mut map = LayerMap::default();
let res = map.batch_update().replace_historic(&not_found, new_version);
assert!(matches!(res, Ok(Replacement::NotFound)), "{res:?}");
}
fn l0_delta_layers_updated_scenario(layer_name: &str, expected_l0: bool) {
let name = LayerFileName::from_str(layer_name).unwrap();
let skeleton = LayerDescriptor::from(name);
@@ -832,8 +813,7 @@ mod tests {
let mut map = LayerMap::default();
// two disjoint Arcs in different lifecycle phases. even if it seems they must be the
// same layer, we use LayerMap::compare_arced_layers as the identity of layers.
// two disjoint Arcs in different lifecycle phases.
assert!(!LayerMap::compare_arced_layers(&remote, &downloaded));
let expected_in_counts = (1, usize::from(expected_l0));

View File

@@ -289,7 +289,7 @@ pub async fn set_new_tenant_config(
conf: &'static PageServerConf,
new_tenant_conf: TenantConfOpt,
tenant_id: TenantId,
) -> Result<(), TenantStateError> {
) -> anyhow::Result<()> {
info!("configuring tenant {tenant_id}");
let tenant = get_tenant(tenant_id, true).await?;
@@ -306,16 +306,16 @@ pub async fn set_new_tenant_config(
/// Gets the tenant from the in-memory data, erroring if it's absent or is not fitting to the query.
/// `active_only = true` allows to query only tenants that are ready for operations, erroring on other kinds of tenants.
pub async fn get_tenant(
tenant_id: TenantId,
active_only: bool,
) -> Result<Arc<Tenant>, TenantStateError> {
pub async fn get_tenant(tenant_id: TenantId, active_only: bool) -> anyhow::Result<Arc<Tenant>> {
let m = TENANTS.read().await;
let tenant = m
.get(&tenant_id)
.ok_or(TenantStateError::NotFound(tenant_id))?;
.with_context(|| format!("Tenant {tenant_id} not found in the local state"))?;
if active_only && !tenant.is_active() {
Err(TenantStateError::NotActive(tenant_id))
anyhow::bail!(
"Tenant {tenant_id} is not active. Current state: {:?}",
tenant.current_state()
)
} else {
Ok(Arc::clone(tenant))
}
@@ -325,56 +325,31 @@ pub async fn delete_timeline(
tenant_id: TenantId,
timeline_id: TimelineId,
ctx: &RequestContext,
) -> Result<(), TenantStateError> {
let tenant = get_tenant(tenant_id, true).await?;
tenant.delete_timeline(timeline_id, ctx).await?;
Ok(())
}
) -> anyhow::Result<()> {
match get_tenant(tenant_id, true).await {
Ok(tenant) => {
tenant.delete_timeline(timeline_id, ctx).await?;
}
Err(e) => anyhow::bail!("Cannot access tenant {tenant_id} in local tenant state: {e:?}"),
}
#[derive(Debug, thiserror::Error)]
pub enum TenantStateError {
#[error("Tenant {0} not found")]
NotFound(TenantId),
#[error("Tenant {0} is stopping")]
IsStopping(TenantId),
#[error("Tenant {0} is not active")]
NotActive(TenantId),
#[error(transparent)]
Other(#[from] anyhow::Error),
Ok(())
}
pub async fn detach_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
detach_ignored: bool,
) -> Result<(), TenantStateError> {
let local_files_cleanup_operation = |tenant_id_to_clean| async move {
let local_tenant_directory = conf.tenant_path(&tenant_id_to_clean);
) -> anyhow::Result<()> {
remove_tenant_from_memory(tenant_id, async {
let local_tenant_directory = conf.tenant_path(&tenant_id);
fs::remove_dir_all(&local_tenant_directory)
.await
.with_context(|| {
format!("local tenant directory {local_tenant_directory:?} removal")
format!("Failed to remove local tenant directory {local_tenant_directory:?}")
})?;
Ok(())
};
let removal_result =
remove_tenant_from_memory(tenant_id, local_files_cleanup_operation(tenant_id)).await;
// Ignored tenants are not present in memory and will bail the removal from memory operation.
// Before returning the error, check for ignored tenant removal case — we only need to clean its local files then.
if detach_ignored && matches!(removal_result, Err(TenantStateError::NotFound(_))) {
let tenant_ignore_mark = conf.tenant_ignore_mark_file_path(tenant_id);
if tenant_ignore_mark.exists() {
info!("Detaching an ignored tenant");
local_files_cleanup_operation(tenant_id)
.await
.with_context(|| format!("Ignored tenant {tenant_id} local files cleanup"))?;
return Ok(());
}
}
removal_result
})
.await
}
pub async fn load_tenant(
@@ -404,7 +379,7 @@ pub async fn load_tenant(
pub async fn ignore_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
) -> Result<(), TenantStateError> {
) -> anyhow::Result<()> {
remove_tenant_from_memory(tenant_id, async {
let ignore_mark_file = conf.tenant_ignore_mark_file_path(tenant_id);
fs::File::create(&ignore_mark_file)
@@ -514,7 +489,7 @@ where
async fn remove_tenant_from_memory<V, F>(
tenant_id: TenantId,
tenant_cleanup: F,
) -> Result<V, TenantStateError>
) -> anyhow::Result<V>
where
F: std::future::Future<Output = anyhow::Result<V>>,
{
@@ -530,9 +505,11 @@ where
| TenantState::Loading
| TenantState::Broken
| TenantState::Active => tenant.set_stopping(),
TenantState::Stopping => return Err(TenantStateError::IsStopping(tenant_id)),
TenantState::Stopping => {
anyhow::bail!("Tenant {tenant_id} is stopping already")
}
},
None => return Err(TenantStateError::NotFound(tenant_id)),
None => anyhow::bail!("Tenant not found for id {tenant_id}"),
}
}
@@ -555,15 +532,10 @@ where
Err(e) => {
let tenants_accessor = TENANTS.read().await;
match tenants_accessor.get(&tenant_id) {
Some(tenant) => {
tenant.set_broken(&e.to_string());
}
None => {
warn!("Tenant {tenant_id} got removed from memory");
return Err(TenantStateError::NotFound(tenant_id));
}
Some(tenant) => tenant.set_broken(&e.to_string()),
None => warn!("Tenant {tenant_id} got removed from memory"),
}
Err(TenantStateError::Other(e))
Err(e)
}
}
}
@@ -583,7 +555,7 @@ pub async fn immediate_gc(
let tenant = guard
.get(&tenant_id)
.map(Arc::clone)
.with_context(|| format!("tenant {tenant_id}"))
.with_context(|| format!("Tenant {tenant_id} not found"))
.map_err(ApiError::NotFound)?;
let gc_horizon = gc_req.gc_horizon.unwrap_or_else(|| tenant.get_gc_horizon());
@@ -633,7 +605,7 @@ pub async fn immediate_compact(
let tenant = guard
.get(&tenant_id)
.map(Arc::clone)
.with_context(|| format!("tenant {tenant_id}"))
.with_context(|| format!("Tenant {tenant_id} not found"))
.map_err(ApiError::NotFound)?;
let timeline = tenant

View File

@@ -210,6 +210,7 @@ pub use download::{is_temp_download_file, list_remote_timelines};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use anyhow::ensure;
use remote_storage::{DownloadError, GenericRemoteStorage};
use std::ops::DerefMut;
use tokio::runtime::Runtime;
@@ -217,10 +218,9 @@ use tracing::{debug, info, warn};
use tracing::{info_span, Instrument};
use utils::lsn::Lsn;
use crate::metrics::{
MeasureRemoteOp, RemoteOpFileKind, RemoteOpKind, RemoteTimelineClientMetrics,
REMOTE_ONDEMAND_DOWNLOADED_BYTES, REMOTE_ONDEMAND_DOWNLOADED_LAYERS,
};
use crate::metrics::RemoteOpFileKind;
use crate::metrics::RemoteOpKind;
use crate::metrics::{MeasureRemoteOp, RemoteTimelineClientMetrics};
use crate::tenant::remote_timeline_client::index::LayerFileMetadata;
use crate::{
config::PageServerConf,
@@ -346,7 +346,7 @@ impl RemoteTimelineClient {
.layer_metadata
.values()
// If we don't have the file size for the layer, don't account for it in the metric.
.map(|ilmd| ilmd.file_size)
.map(|ilmd| ilmd.file_size.unwrap_or(0))
.sum()
} else {
0
@@ -419,9 +419,33 @@ impl RemoteTimelineClient {
.await?
};
REMOTE_ONDEMAND_DOWNLOADED_LAYERS.inc();
REMOTE_ONDEMAND_DOWNLOADED_BYTES.inc_by(downloaded_size);
// Update the metadata for given layer file. The remote index file
// might be missing some information for the file; this allows us
// to fill in the missing details.
if layer_metadata.file_size().is_none() {
let new_metadata = LayerFileMetadata::new(downloaded_size);
let mut guard = self.upload_queue.lock().unwrap();
let upload_queue = guard.initialized_mut()?;
if let Some(upgraded) = upload_queue.latest_files.get_mut(layer_file_name) {
if upgraded.merge(&new_metadata) {
upload_queue.latest_files_changes_since_metadata_upload_scheduled += 1;
}
// If we don't do an index file upload inbetween here and restart,
// the value will go back down after pageserver restart, since we will
// have lost this data point.
// But, we upload index part fairly frequently, and restart pageserver rarely.
// So, by accounting eagerly, we present a most-of-the-time-more-accurate value sooner.
self.metrics
.remote_physical_size_gauge()
.add(downloaded_size);
} else {
// The file should exist, since we just downloaded it.
warn!(
"downloaded file {:?} not found in local copy of the index file",
layer_file_name
);
}
}
Ok(downloaded_size)
}
@@ -521,6 +545,13 @@ impl RemoteTimelineClient {
let mut guard = self.upload_queue.lock().unwrap();
let upload_queue = guard.initialized_mut()?;
// The file size can be missing for files that were created before we tracked that
// in the metadata, but it should be present for any new files we create.
ensure!(
layer_metadata.file_size().is_some(),
"file size not initialized in metadata"
);
upload_queue
.latest_files
.insert(layer_file_name.clone(), layer_metadata.clone());

View File

@@ -6,13 +6,11 @@
use std::collections::HashSet;
use std::future::Future;
use std::path::Path;
use std::time::Duration;
use anyhow::{anyhow, Context};
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tracing::{info, warn};
use tracing::{error, info, warn};
use crate::config::PageServerConf;
use crate::tenant::storage_layer::LayerFileName;
@@ -21,15 +19,13 @@ use remote_storage::{DownloadError, GenericRemoteStorage};
use utils::crashsafe::path_with_suffix_extension;
use utils::id::{TenantId, TimelineId};
use super::index::{IndexPart, LayerFileMetadata};
use super::index::{IndexPart, IndexPartUnclean, LayerFileMetadata};
use super::{FAILED_DOWNLOAD_RETRIES, FAILED_DOWNLOAD_WARN_THRESHOLD};
async fn fsync_path(path: impl AsRef<std::path::Path>) -> Result<(), std::io::Error> {
fs::File::open(path).await?.sync_all().await
}
static MAX_DOWNLOAD_DURATION: Duration = Duration::from_secs(120);
///
/// If 'metadata' is given, we will validate that the downloaded file's size matches that
/// in the metadata. (In the future, we might do more cross-checks, like CRC validation)
@@ -68,28 +64,22 @@ pub async fn download_layer_file<'a>(
// TODO: this doesn't use the cached fd for some reason?
let mut destination_file = fs::File::create(&temp_file_path).await.with_context(|| {
format!(
"create a destination file for layer '{}'",
"Failed to create a destination file for layer '{}'",
temp_file_path.display()
)
})
.map_err(DownloadError::Other)?;
let mut download = storage.download(&remote_path).await.with_context(|| {
format!(
"open a download stream for layer with remote storage path '{remote_path:?}'"
"Failed to open a download stream for layer with remote storage path '{remote_path:?}'"
)
})
.map_err(DownloadError::Other)?;
let bytes_amount = tokio::time::timeout(MAX_DOWNLOAD_DURATION, tokio::io::copy(&mut download.download_stream, &mut destination_file))
.await
.map_err(|e| DownloadError::Other(anyhow::anyhow!("Timed out {:?}", e)))?
.with_context(|| {
format!("Failed to download layer with remote storage path '{remote_path:?}' into file {temp_file_path:?}")
})
.map_err(DownloadError::Other)?;
let bytes_amount = tokio::io::copy(&mut download.download_stream, &mut destination_file).await.with_context(|| {
format!("Failed to download layer with remote storage path '{remote_path:?}' into file {temp_file_path:?}")
})
.map_err(DownloadError::Other)?;
Ok((destination_file, bytes_amount))
},
&format!("download {remote_path:?}"),
).await?;
@@ -113,11 +103,16 @@ pub async fn download_layer_file<'a>(
})
.map_err(DownloadError::Other)?;
let expected = layer_metadata.file_size();
if expected != bytes_amount {
return Err(DownloadError::Other(anyhow!(
"According to layer file metadata should have downloaded {expected} bytes but downloaded {bytes_amount} bytes into file {temp_file_path:?}",
)));
match layer_metadata.file_size() {
Some(expected) if expected != bytes_amount => {
return Err(DownloadError::Other(anyhow!(
"According to layer file metadata should have downloaded {expected} bytes but downloaded {bytes_amount} bytes into file '{}'",
temp_file_path.display()
)));
}
Some(_) | None => {
// matches, or upgrading from an earlier IndexPart version
}
}
// not using sync_data because it can lose file size update
@@ -256,12 +251,14 @@ pub(super) async fn download_index_part(
)
.await?;
let index_part: IndexPart = serde_json::from_slice(&index_part_bytes)
let index_part: IndexPartUnclean = serde_json::from_slice(&index_part_bytes)
.with_context(|| {
format!("Failed to deserialize index part file into file {index_part_path:?}")
})
.map_err(DownloadError::Other)?;
let index_part = index_part.remove_unclean_layer_file_names();
Ok(index_part)
}
@@ -303,7 +300,7 @@ where
}
Err(DownloadError::Other(ref err)) => {
// Operation failed FAILED_DOWNLOAD_RETRIES times. Time to give up.
warn!("{description} still failed after {attempts} retries, giving up: {err:?}");
error!("{description} still failed after {attempts} retries, giving up: {err:?}");
return result;
}
}

Some files were not shown because too many files have changed in this diff Show More