Compare commits

..

2 Commits

Author SHA1 Message Date
Konstantin Knizhnik
a7f04ace3d Make clippy happy 2022-02-16 13:26:22 +03:00
Konstantin Knizhnik
dc3f3d0ace Calculate postgres checksum for FPI stored in page server 2022-02-16 13:16:35 +03:00
129 changed files with 2995 additions and 4946 deletions

View File

@@ -1,10 +0,0 @@
[defaults]
localhost_warning = False
host_key_checking = False
timeout = 30
[ssh_connection]
ssh_args = -F ./ansible.ssh.cfg
scp_if_ssh = True
pipelining = True

View File

@@ -1,11 +0,0 @@
Host tele.zenith.tech
User admin
Port 3023
StrictHostKeyChecking no
UserKnownHostsFile /dev/null
Host * !tele.zenith.tech
User admin
StrictHostKeyChecking no
UserKnownHostsFile /dev/null
ProxyJump tele.zenith.tech

View File

@@ -1,174 +0,0 @@
- name: Upload Zenith binaries
hosts: pageservers:safekeepers
gather_facts: False
remote_user: admin
vars:
force_deploy: false
tasks:
- name: get latest version of Zenith binaries
ignore_errors: true
register: current_version_file
set_fact:
current_version: "{{ lookup('file', '.zenith_current_version') | trim }}"
tags:
- pageserver
- safekeeper
- name: set zero value for current_version
when: current_version_file is failed
set_fact:
current_version: "0"
tags:
- pageserver
- safekeeper
- name: get deployed version from content of remote file
ignore_errors: true
ansible.builtin.slurp:
src: /usr/local/.zenith_current_version
register: remote_version_file
tags:
- pageserver
- safekeeper
- name: decode remote file content
when: remote_version_file is succeeded
set_fact:
remote_version: "{{ remote_version_file['content'] | b64decode | trim }}"
tags:
- pageserver
- safekeeper
- name: set zero value for remote_version
when: remote_version_file is failed
set_fact:
remote_version: "0"
tags:
- pageserver
- safekeeper
- name: inform about versions
debug: msg="Version to deploy - {{ current_version }}, version on storage node - {{ remote_version }}"
tags:
- pageserver
- safekeeper
- name: upload and extract Zenith binaries to /usr/local
when: current_version > remote_version or force_deploy
ansible.builtin.unarchive:
owner: root
group: root
src: zenith_install.tar.gz
dest: /usr/local
become: true
tags:
- pageserver
- safekeeper
- binaries
- putbinaries
- name: Deploy pageserver
hosts: pageservers
gather_facts: False
remote_user: admin
vars:
force_deploy: false
tasks:
- name: init pageserver
when: current_version > remote_version or force_deploy
shell:
cmd: sudo -u pageserver /usr/local/bin/pageserver -c "pg_distrib_dir='/usr/local'" --init -D /storage/pageserver/data
args:
creates: "/storage/pageserver/data/tenants"
environment:
ZENITH_REPO_DIR: "/storage/pageserver/data"
LD_LIBRARY_PATH: "/usr/local/lib"
become: true
tags:
- pageserver
- name: upload systemd service definition
when: current_version > remote_version or force_deploy
ansible.builtin.template:
src: systemd/pageserver.service
dest: /etc/systemd/system/pageserver.service
owner: root
group: root
mode: '0644'
become: true
tags:
- pageserver
- name: start systemd service
when: current_version > remote_version or force_deploy
ansible.builtin.systemd:
daemon_reload: yes
name: pageserver
enabled: yes
state: restarted
become: true
tags:
- pageserver
- name: post version to console
when: (current_version > remote_version or force_deploy) and console_mgmt_base_url is defined
shell:
cmd: |
INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
curl -sfS -d '{"version": {{ current_version }} }' -X POST {{ console_mgmt_base_url }}/api/v1/pageservers/$INSTANCE_ID
tags:
- pageserver
- name: Deploy safekeeper
hosts: safekeepers
gather_facts: False
remote_user: admin
vars:
force_deploy: false
tasks:
# in the future safekeepers should discover pageservers byself
# but currently use first pageserver that was discovered
- name: set first pageserver var for safekeepers
when: current_version > remote_version or force_deploy
set_fact:
first_pageserver: "{{ hostvars[groups['pageservers'][0]]['inventory_hostname'] }}"
tags:
- safekeeper
- name: upload systemd service definition
when: current_version > remote_version or force_deploy
ansible.builtin.template:
src: systemd/safekeeper.service
dest: /etc/systemd/system/safekeeper.service
owner: root
group: root
mode: '0644'
become: true
tags:
- safekeeper
- name: start systemd service
when: current_version > remote_version or force_deploy
ansible.builtin.systemd:
daemon_reload: yes
name: safekeeper
enabled: yes
state: restarted
become: true
tags:
- safekeeper
- name: post version to console
when: (current_version > remote_version or force_deploy) and console_mgmt_base_url is defined
shell:
cmd: |
INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
curl -sfS -d '{"version": {{ current_version }} }' -X POST {{ console_mgmt_base_url }}/api/v1/safekeepers/$INSTANCE_ID
tags:
- safekeeper

View File

@@ -1,52 +0,0 @@
#!/bin/bash
set -e
RELEASE=${RELEASE:-false}
# look at docker hub for latest tag fo zenith docker image
if [ "${RELEASE}" = "true" ]; then
echo "search latest relase tag"
VERSION=$(curl -s https://registry.hub.docker.com/v1/repositories/zenithdb/zenith/tags |jq -r -S '.[].name' | grep release | sed 's/release-//g' | tail -1)
if [ -z "${VERSION}" ]; then
echo "no any docker tags found, exiting..."
exit 1
else
TAG="release-${VERSION}"
fi
else
echo "search latest dev tag"
VERSION=$(curl -s https://registry.hub.docker.com/v1/repositories/zenithdb/zenith/tags |jq -r -S '.[].name' | grep -v release | tail -1)
if [ -z "${VERSION}" ]; then
echo "no any docker tags found, exiting..."
exit 1
else
TAG="${VERSION}"
fi
fi
echo "found ${VERSION}"
# do initial cleanup
rm -rf zenith_install postgres_install.tar.gz zenith_install.tar.gz .zenith_current_version
mkdir zenith_install
# retrive binaries from docker image
echo "getting binaries from docker image"
docker pull --quiet zenithdb/zenith:${TAG}
ID=$(docker create zenithdb/zenith:${TAG})
docker cp ${ID}:/data/postgres_install.tar.gz .
tar -xzf postgres_install.tar.gz -C zenith_install
docker cp ${ID}:/usr/local/bin/pageserver zenith_install/bin/
docker cp ${ID}:/usr/local/bin/safekeeper zenith_install/bin/
docker cp ${ID}:/usr/local/bin/proxy zenith_install/bin/
docker cp ${ID}:/usr/local/bin/postgres zenith_install/bin/
docker rm -vf ${ID}
# store version to file (for ansible playbooks) and create binaries tarball
echo ${VERSION} > zenith_install/.zenith_current_version
echo ${VERSION} > .zenith_current_version
tar -czf zenith_install.tar.gz -C zenith_install .
# do final cleaup
rm -rf zenith_install postgres_install.tar.gz

View File

@@ -1,7 +0,0 @@
[pageservers]
zenith-1-ps-1
[safekeepers]
zenith-1-sk-1
zenith-1-sk-2
zenith-1-sk-3

View File

@@ -1,7 +0,0 @@
[pageservers]
zenith-us-stage-ps-1
[safekeepers]
zenith-us-stage-sk-1
zenith-us-stage-sk-2
zenith-us-stage-sk-3

View File

@@ -1,18 +0,0 @@
[Unit]
Description=Zenith pageserver
After=network.target auditd.service
[Service]
Type=simple
User=pageserver
Environment=RUST_BACKTRACE=1 ZENITH_REPO_DIR=/storage/pageserver LD_LIBRARY_PATH=/usr/local/lib
ExecStart=/usr/local/bin/pageserver -c "pg_distrib_dir='/usr/local'" -c "listen_pg_addr='0.0.0.0:6400'" -c "listen_http_addr='0.0.0.0:9898'" -D /storage/pageserver/data
ExecReload=/bin/kill -HUP $MAINPID
KillMode=mixed
KillSignal=SIGINT
Restart=on-failure
TimeoutSec=10
LimitNOFILE=30000000
[Install]
WantedBy=multi-user.target

View File

@@ -1,18 +0,0 @@
[Unit]
Description=Zenith safekeeper
After=network.target auditd.service
[Service]
Type=simple
User=safekeeper
Environment=RUST_BACKTRACE=1 ZENITH_REPO_DIR=/storage/safekeeper/data LD_LIBRARY_PATH=/usr/local/lib
ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}.local:6500 --listen-http {{ inventory_hostname }}.local:7676 -p {{ first_pageserver }}:6400 -D /storage/safekeeper/data
ExecReload=/bin/kill -HUP $MAINPID
KillMode=mixed
KillSignal=SIGINT
Restart=on-failure
TimeoutSec=10
LimitNOFILE=30000000
[Install]
WantedBy=multi-user.target

View File

@@ -471,78 +471,46 @@ jobs:
docker build -t zenithdb/compute-node:latest vendor/postgres && docker push zenithdb/compute-node:latest
docker tag zenithdb/compute-node:latest zenithdb/compute-node:${DOCKER_TAG} && docker push zenithdb/compute-node:${DOCKER_TAG}
# Build production zenithdb/zenith:release image and push it to Docker hub
docker-image-release:
docker:
- image: cimg/base:2021.04
steps:
- checkout
- setup_remote_docker:
docker_layer_caching: true
- run:
name: Init postgres submodule
command: git submodule update --init --depth 1
- run:
name: Build and push Docker image
command: |
echo $DOCKER_PWD | docker login -u $DOCKER_LOGIN --password-stdin
DOCKER_TAG="release-$(git log --oneline|wc -l)"
docker build --build-arg GIT_VERSION=$CIRCLE_SHA1 -t zenithdb/zenith:release . && docker push zenithdb/zenith:release
docker tag zenithdb/zenith:release zenithdb/zenith:${DOCKER_TAG} && docker push zenithdb/zenith:${DOCKER_TAG}
# Build production zenithdb/compute-node:release image and push it to Docker hub
docker-image-compute-release:
docker:
- image: cimg/base:2021.04
steps:
- checkout
- setup_remote_docker:
docker_layer_caching: true
# Build zenithdb/compute-tools:release image and push it to Docker hub
# TODO: this should probably also use versioned tag, not just :latest.
# XXX: but should it? We build and use it only locally now.
- run:
name: Build and push compute-tools Docker image
command: |
echo $DOCKER_PWD | docker login -u $DOCKER_LOGIN --password-stdin
docker build -t zenithdb/compute-tools:release -f Dockerfile.compute-tools .
docker push zenithdb/compute-tools:release
- run:
name: Init postgres submodule
command: git submodule update --init --depth 1
- run:
name: Build and push compute-node Docker image
command: |
echo $DOCKER_PWD | docker login -u $DOCKER_LOGIN --password-stdin
DOCKER_TAG="release-$(git log --oneline|wc -l)"
docker build -t zenithdb/compute-node:release vendor/postgres && docker push zenithdb/compute-node:release
docker tag zenithdb/compute-node:release zenithdb/compute-node:${DOCKER_TAG} && docker push zenithdb/compute-node:${DOCKER_TAG}
deploy-staging:
docker:
- image: cimg/python:3.10
steps:
- checkout
- setup_remote_docker
- run:
name: Get Zenith binaries
command: |
rm -rf zenith_install postgres_install.tar.gz zenith_install.tar.gz
mkdir zenith_install
DOCKER_TAG=$(git log --oneline|wc -l)
docker pull --quiet zenithdb/zenith:${DOCKER_TAG}
ID=$(docker create zenithdb/zenith:${DOCKER_TAG})
docker cp $ID:/data/postgres_install.tar.gz .
tar -xzf postgres_install.tar.gz -C zenith_install && rm postgres_install.tar.gz
docker cp $ID:/usr/local/bin/pageserver zenith_install/bin/
docker cp $ID:/usr/local/bin/safekeeper zenith_install/bin/
docker cp $ID:/usr/local/bin/proxy zenith_install/bin/
docker cp $ID:/usr/local/bin/postgres zenith_install/bin/
docker rm -v $ID
echo ${DOCKER_TAG} | tee zenith_install/.zenith_current_version
tar -czf zenith_install.tar.gz -C zenith_install .
ls -la zenith_install.tar.gz
- run:
name: Setup ansible
command: |
pip install --progress-bar off --user ansible boto3
ansible-galaxy collection install amazon.aws
- run:
name: Redeploy
name: Apply re-deploy playbook
environment:
ANSIBLE_HOST_KEY_CHECKING: false
command: |
cd "$(pwd)/.circleci/ansible"
./get_binaries.sh
echo "${TELEPORT_SSH_KEY}" | tr -d '\n'| base64 --decode >ssh-key
echo "${TELEPORT_SSH_CERT}" | tr -d '\n'| base64 --decode >ssh-key-cert.pub
chmod 0600 ssh-key
ssh-add ssh-key
rm -f ssh-key ssh-key-cert.pub
ansible-playbook deploy.yaml -i staging.hosts
rm -f zenith_install.tar.gz .zenith_current_version
echo "${STAGING_SSH_KEY}" | base64 --decode | ssh-add -
export AWS_REGION=${STAGING_AWS_REGION}
export AWS_ACCESS_KEY_ID=${STAGING_AWS_ACCESS_KEY_ID}
export AWS_SECRET_ACCESS_KEY=${STAGING_AWS_SECRET_ACCESS_KEY}
ansible-playbook .circleci/storage-redeploy.playbook.yml
rm -f zenith_install.tar.gz
deploy-staging-proxy:
docker:
@@ -565,57 +533,7 @@ jobs:
name: Re-deploy proxy
command: |
DOCKER_TAG=$(git log --oneline|wc -l)
helm upgrade zenith-proxy zenithdb/zenith-proxy --install -f .circleci/helm-values/staging.proxy.yaml --set image.tag=${DOCKER_TAG} --wait
deploy-release:
docker:
- image: cimg/python:3.10
steps:
- checkout
- setup_remote_docker
- run:
name: Setup ansible
command: |
pip install --progress-bar off --user ansible boto3
- run:
name: Redeploy
command: |
cd "$(pwd)/.circleci/ansible"
RELEASE=true ./get_binaries.sh
echo "${TELEPORT_SSH_KEY}" | tr -d '\n'| base64 --decode >ssh-key
echo "${TELEPORT_SSH_CERT}" | tr -d '\n'| base64 --decode >ssh-key-cert.pub
chmod 0600 ssh-key
ssh-add ssh-key
rm -f ssh-key ssh-key-cert.pub
ansible-playbook deploy.yaml -i production.hosts -e console_mgmt_base_url=http://console-release.local
rm -f zenith_install.tar.gz .zenith_current_version
deploy-release-proxy:
docker:
- image: cimg/base:2021.04
environment:
KUBECONFIG: .kubeconfig
steps:
- checkout
- run:
name: Store kubeconfig file
command: |
echo "${PRODUCTION_KUBECONFIG_DATA}" | base64 --decode > ${KUBECONFIG}
chmod 0600 ${KUBECONFIG}
- run:
name: Setup helm v3
command: |
curl -s https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash
helm repo add zenithdb https://zenithdb.github.io/helm-charts
- run:
name: Re-deploy proxy
command: |
DOCKER_TAG="release-$(git log --oneline|wc -l)"
helm upgrade zenith-proxy zenithdb/zenith-proxy --install -f .circleci/helm-values/production.proxy.yaml --set image.tag=${DOCKER_TAG} --wait
helm upgrade zenith-proxy zenithdb/zenith-proxy --install -f .circleci/proxy.staging.yaml --set image.tag=${DOCKER_TAG} --wait
# Trigger a new remote CI job
remote-ci-trigger:
@@ -751,47 +669,6 @@ workflows:
- main
requires:
- docker-image
- docker-image-release:
# Context gives an ability to login
context: Docker Hub
# Build image only for commits to main
filters:
branches:
only:
- release
requires:
- pg_regress-tests-release
- other-tests-release
- docker-image-compute-release:
# Context gives an ability to login
context: Docker Hub
# Build image only for commits to main
filters:
branches:
only:
- release
requires:
- pg_regress-tests-release
- other-tests-release
- deploy-release:
# Context gives an ability to login
context: Docker Hub
# deploy only for commits to main
filters:
branches:
only:
- release
requires:
- docker-image-release
- deploy-release-proxy:
# deploy only for commits to main
filters:
branches:
only:
- release
requires:
- docker-image-release
- remote-ci-trigger:
# Context passes credentials for gh api
context: CI_ACCESS_TOKEN

View File

@@ -1,35 +0,0 @@
# Helm chart values for zenith-proxy.
# This is a YAML-formatted file.
settings:
authEndpoint: "https://console.zenith.tech/authenticate_proxy_request/"
uri: "https://console.zenith.tech/psql_session/"
# -- Additional labels for zenith-proxy pods
podLabels:
zenith_service: proxy
zenith_env: production
zenith_region: us-west-2
zenith_region_slug: oregon
service:
annotations:
service.beta.kubernetes.io/aws-load-balancer-type: external
service.beta.kubernetes.io/aws-load-balancer-nlb-target-type: ip
service.beta.kubernetes.io/aws-load-balancer-scheme: internal
external-dns.alpha.kubernetes.io/hostname: proxy-release.local
type: LoadBalancer
exposedService:
annotations:
service.beta.kubernetes.io/aws-load-balancer-type: external
service.beta.kubernetes.io/aws-load-balancer-nlb-target-type: ip
service.beta.kubernetes.io/aws-load-balancer-scheme: internet-facing
external-dns.alpha.kubernetes.io/hostname: start.zenith.tech
metrics:
enabled: true
serviceMonitor:
enabled: true
selector:
release: kube-prometheus-stack

View File

@@ -0,0 +1,138 @@
- name: discover storage nodes
hosts: localhost
connection: local
gather_facts: False
tasks:
- name: discover safekeepers
no_log: true
ec2_instance_info:
filters:
"tag:zenith_env": "staging"
"tag:zenith_service": "safekeeper"
register: ec2_safekeepers
- name: discover pageservers
no_log: true
ec2_instance_info:
filters:
"tag:zenith_env": "staging"
"tag:zenith_service": "pageserver"
register: ec2_pageservers
- name: add safekeepers to host group
no_log: true
add_host:
name: safekeeper-{{ ansible_loop.index }}
ansible_host: "{{ item.public_ip_address }}"
groups:
- storage
- safekeepers
with_items: "{{ ec2_safekeepers.instances }}"
loop_control:
extended: yes
- name: add pageservers to host group
no_log: true
add_host:
name: pageserver-{{ ansible_loop.index }}
ansible_host: "{{ item.public_ip_address }}"
groups:
- storage
- pageservers
with_items: "{{ ec2_pageservers.instances }}"
loop_control:
extended: yes
- name: Retrive versions
hosts: storage
gather_facts: False
remote_user: admin
tasks:
- name: Get current version of binaries
set_fact:
current_version: "{{lookup('file', '../zenith_install/.zenith_current_version') }}"
- name: Check that file with version exists on host
stat:
path: /usr/local/.zenith_current_version
register: version_file
- name: Try to get current version from the host
when: version_file.stat.exists
ansible.builtin.fetch:
src: /usr/local/.zenith_current_version
dest: .remote_version.{{ inventory_hostname }}
fail_on_missing: no
flat: yes
- name: Store remote version to variable
when: version_file.stat.exists
set_fact:
remote_version: "{{ lookup('file', '.remote_version.{{ inventory_hostname }}') }}"
- name: Store default value of remote version to variable in case when remote version file not found
when: not version_file.stat.exists
set_fact:
remote_version: "000"
- name: Extract Zenith binaries
hosts: storage
gather_facts: False
remote_user: admin
tasks:
- name: Inform about version conflict
when: current_version <= remote_version
debug: msg="Current version {{ current_version }} LE than remote {{ remote_version }}"
- name: Extract Zenith binaries to /usr/local
when: current_version > remote_version
ansible.builtin.unarchive:
src: ../zenith_install.tar.gz
dest: /usr/local
become: true
- name: Restart safekeepers
hosts: safekeepers
gather_facts: False
remote_user: admin
tasks:
- name: Inform about version conflict
when: current_version <= remote_version
debug: msg="Current version {{ current_version }} LE than remote {{ remote_version }}"
- name: Restart systemd service
when: current_version > remote_version
ansible.builtin.systemd:
daemon_reload: yes
name: safekeeper
enabled: yes
state: restarted
become: true
- name: Restart pageservers
hosts: pageservers
gather_facts: False
remote_user: admin
tasks:
- name: Inform about version conflict
when: current_version <= remote_version
debug: msg="Current version {{ current_version }} LE than remote {{ remote_version }}"
- name: Restart systemd service
when: current_version > remote_version
ansible.builtin.systemd:
daemon_reload: yes
name: pageserver
enabled: yes
state: restarted
become: true

98
Cargo.lock generated
View File

@@ -23,17 +23,6 @@ version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "739f4a8db6605981345c5654f3a85b056ce52f37a39d34da03f25bf2151ea16e"
[[package]]
name = "ahash"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
dependencies = [
"getrandom",
"once_cell",
"version_check",
]
[[package]]
name = "aho-corasick"
version = "0.7.18"
@@ -552,17 +541,6 @@ dependencies = [
"termcolor",
]
[[package]]
name = "fail"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec3245a0ca564e7f3c797d20d833a6870f57a728ac967d5225b3ffdef4465011"
dependencies = [
"lazy_static",
"log",
"rand",
]
[[package]]
name = "fallible-iterator"
version = "0.2.0"
@@ -791,7 +769,7 @@ version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04"
dependencies = [
"ahash 0.4.7",
"ahash",
]
[[package]]
@@ -799,9 +777,6 @@ name = "hashbrown"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
dependencies = [
"ahash 0.7.6",
]
[[package]]
name = "hermit-abi"
@@ -921,7 +896,7 @@ dependencies = [
"hyper",
"rustls 0.20.2",
"tokio",
"tokio-rustls 0.23.2",
"tokio-rustls",
]
[[package]]
@@ -1006,7 +981,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afabcc15e437a6484fc4f12d0fd63068fe457bf93f1c148d3d9649c60b103f32"
dependencies = [
"base64 0.12.3",
"pem 0.8.3",
"pem",
"ring",
"serde",
"serde_json",
@@ -1300,7 +1275,6 @@ dependencies = [
"crc32c",
"crossbeam-utils",
"daemonize",
"fail",
"futures",
"hex",
"hex-literal",
@@ -1378,15 +1352,6 @@ dependencies = [
"regex",
]
[[package]]
name = "pem"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9a3b09a20e374558580a4914d3b7d89bd61b954a5a5e1dcbea98753addb1947"
dependencies = [
"base64 0.13.0",
]
[[package]]
name = "percent-encoding"
version = "2.1.0"
@@ -1591,25 +1556,17 @@ dependencies = [
"anyhow",
"bytes",
"clap 3.0.14",
"futures",
"hashbrown 0.11.2",
"hex",
"hyper",
"lazy_static",
"md5",
"parking_lot",
"pin-project-lite",
"rand",
"rcgen",
"reqwest",
"rustls 0.19.1",
"scopeguard",
"serde",
"serde_json",
"tokio",
"tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
"tokio-postgres-rustls",
"tokio-rustls 0.22.0",
"zenith_metrics",
"zenith_utils",
]
@@ -1663,18 +1620,6 @@ dependencies = [
"rand_core",
]
[[package]]
name = "rcgen"
version = "0.8.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5911d1403f4143c9d56a702069d593e8d0f3fab880a85e103604d0893ea31ba7"
dependencies = [
"chrono",
"pem 1.0.2",
"ring",
"yasna",
]
[[package]]
name = "redox_syscall"
version = "0.2.10"
@@ -1758,7 +1703,7 @@ dependencies = [
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-rustls 0.23.2",
"tokio-rustls",
"tokio-util",
"url",
"wasm-bindgen",
@@ -2320,32 +2265,6 @@ dependencies = [
"tokio-util",
]
[[package]]
name = "tokio-postgres-rustls"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bd8c37d8c23cb6ecdc32fc171bade4e9c7f1be65f693a17afbaad02091a0a19"
dependencies = [
"futures",
"ring",
"rustls 0.19.1",
"tokio",
"tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
"tokio-rustls 0.22.0",
"webpki 0.21.4",
]
[[package]]
name = "tokio-rustls"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6"
dependencies = [
"rustls 0.19.1",
"tokio",
"webpki 0.21.4",
]
[[package]]
name = "tokio-rustls"
version = "0.23.2"
@@ -2811,15 +2730,6 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2d7d3948613f75c98fd9328cfdcc45acc4d360655289d0a7d4ec931392200a3"
[[package]]
name = "yasna"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e262a29d0e61ccf2b6190d7050d4b237535fc76ce4c1210d9caa316f71dffa75"
dependencies = [
"chrono",
]
[[package]]
name = "zenith"
version = "0.1.0"

View File

@@ -16,8 +16,3 @@ members = [
# This is useful for profiling and, to some extent, debug.
# Besides, debug info should not affect the performance.
debug = true
# This is only needed for proxy's tests
# TODO: we should probably fork tokio-postgres-rustls instead
[patch.crates-io]
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }

View File

@@ -6,7 +6,7 @@
# Build Postgres separately --- this layer will be rebuilt only if one of
# mentioned paths will get any changes.
#
FROM zimg/rust:1.56 AS pg-build
FROM zenithdb/build:buster AS pg-build
WORKDIR /zenith
COPY ./vendor/postgres vendor/postgres
COPY ./Makefile Makefile
@@ -20,7 +20,7 @@ RUN rm -rf postgres_install/build
# TODO: build cargo deps as separate layer. We used cargo-chef before but that was
# net time waste in a lot of cases. Copying Cargo.lock with empty lib.rs should do the work.
#
FROM zimg/rust:1.56 AS build
FROM zenithdb/build:buster AS build
ARG GIT_VERSION
RUN if [ -z "$GIT_VERSION" ]; then echo "GIT_VERSION is reqired, use build_arg to pass it"; exit 1; fi
@@ -34,7 +34,7 @@ RUN GIT_VERSION=$GIT_VERSION cargo build --release
#
# Copy binaries to resulting image.
#
FROM debian:bullseye-slim
FROM debian:buster-slim
WORKDIR /data
RUN apt-get update && apt-get -yq install libreadline-dev libseccomp-dev openssl ca-certificates && \

16
Dockerfile.build Normal file
View File

@@ -0,0 +1,16 @@
#
# Image with all the required dependencies to build https://github.com/zenithdb/zenith
# and Postgres from https://github.com/zenithdb/postgres
# Also includes some rust development and build tools.
# NB: keep in sync with rust image version in .circle/config.yml
#
FROM rust:1.56.1-slim-buster
WORKDIR /zenith
# Install postgres and zenith build dependencies
# clang is for rocksdb
RUN apt-get update && apt-get -yq install automake libtool build-essential bison flex libreadline-dev zlib1g-dev libxml2-dev \
libseccomp-dev pkg-config libssl-dev clang
# Install rust tools
RUN rustup component add clippy && cargo install cargo-audit

View File

@@ -171,7 +171,7 @@ impl PgQuote for PgIdent {
/// always quotes provided string with `""` and escapes every `"`. Not idempotent,
/// i.e. if string is already escaped it will be escaped again.
fn quote(&self) -> String {
let result = format!("\"{}\"", self.replace('"', "\"\""));
let result = format!("\"{}\"", self.replace("\"", "\"\""));
result
}
}

View File

@@ -215,7 +215,7 @@ pub fn handle_databases(spec: &ClusterSpec, client: &mut Client) -> Result<()> {
if let Some(r) = pg_db {
// XXX: db owner name is returned as quoted string from Postgres,
// when quoting is needed.
let new_owner = if r.owner.starts_with('"') {
let new_owner = if r.owner.starts_with('\"') {
db.owner.quote()
} else {
db.owner.clone()

View File

@@ -5,16 +5,16 @@ listen_http_addr = '127.0.0.1:9898'
auth_type = 'Trust'
[[safekeepers]]
id = 1
name = 'sk1'
pg_port = 5454
http_port = 7676
[[safekeepers]]
id = 2
name = 'sk2'
pg_port = 5455
http_port = 7677
[[safekeepers]]
id = 3
name = 'sk3'
pg_port = 5456
http_port = 7678

View File

@@ -6,6 +6,6 @@ listen_http_addr = '127.0.0.1:9898'
auth_type = 'Trust'
[[safekeepers]]
id = 1
name = 'single'
pg_port = 5454
http_port = 7676

View File

@@ -334,26 +334,14 @@ impl PostgresNode {
if let Some(lsn) = self.lsn {
conf.append("recovery_target_lsn", &lsn.to_string());
}
conf.append_line("");
// Configure backpressure
// - Replication write lag depends on how fast the walreceiver can process incoming WAL.
// This lag determines latency of get_page_at_lsn. Speed of applying WAL is about 10MB/sec,
// so to avoid expiration of 1 minute timeout, this lag should not be larger than 600MB.
// Actually latency should be much smaller (better if < 1sec). But we assume that recently
// updates pages are not requested from pageserver.
// - Replication flush lag depends on speed of persisting data by checkpointer (creation of
// delta/image layers) and advancing disk_consistent_lsn. Safekeepers are able to
// remove/archive WAL only beyond disk_consistent_lsn. Too large a lag can cause long
// recovery time (in case of pageserver crash) and disk space overflow at safekeepers.
// - Replication apply lag depends on speed of uploading changes to S3 by uploader thread.
// To be able to restore database in case of pageserver node crash, safekeeper should not
// remove WAL beyond this point. Too large lag can cause space exhaustion in safekeepers
// (if they are not able to upload WAL to S3).
conf.append("max_replication_write_lag", "500MB");
conf.append("max_replication_flush_lag", "10GB");
if !self.env.safekeepers.is_empty() {
// Configure backpressure
// In setup with safekeepers apply_lag depends on
// speed of data checkpointing on pageserver (see disk_consistent_lsn).
conf.append("max_replication_apply_lag", "1500MB");
// Configure the node to connect to the safekeepers
conf.append("synchronous_standby_names", "walproposer");
@@ -366,6 +354,11 @@ impl PostgresNode {
.join(",");
conf.append("wal_acceptors", &wal_acceptors);
} else {
// Configure backpressure
// In setup without safekeepers, flush_lag depends on
// speed of of data checkpointing on pageserver (see disk_consistent_lsn)
conf.append("max_replication_flush_lag", "1500MB");
// We only use setup without safekeepers for tests,
// and don't care about data durability on pageserver,
// so set more relaxed synchronous_commit.

View File

@@ -12,9 +12,7 @@ use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use zenith_utils::auth::{encode_from_key_file, Claims, Scope};
use zenith_utils::postgres_backend::AuthType;
use zenith_utils::zid::{HexZTenantId, ZNodeId, ZTenantId};
use crate::safekeeper::SafekeeperNode;
use zenith_utils::zid::{opt_display_serde, ZTenantId};
//
// This data structures represents zenith CLI config
@@ -47,8 +45,9 @@ pub struct LocalEnv {
// Default tenant ID to use with the 'zenith' command line utility, when
// --tenantid is not explicitly specified.
#[serde(with = "opt_display_serde")]
#[serde(default)]
pub default_tenantid: Option<HexZTenantId>,
pub default_tenantid: Option<ZTenantId>,
// used to issue tokens during e.g pg start
#[serde(default)]
@@ -63,8 +62,6 @@ pub struct LocalEnv {
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(default)]
pub struct PageServerConf {
// node id
pub id: ZNodeId,
// Pageserver connection settings
pub listen_pg_addr: String,
pub listen_http_addr: String,
@@ -79,7 +76,6 @@ pub struct PageServerConf {
impl Default for PageServerConf {
fn default() -> Self {
Self {
id: ZNodeId(0),
listen_pg_addr: String::new(),
listen_http_addr: String::new(),
auth_type: AuthType::Trust,
@@ -91,7 +87,7 @@ impl Default for PageServerConf {
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(default)]
pub struct SafekeeperConf {
pub id: ZNodeId,
pub name: String,
pub pg_port: u16,
pub http_port: u16,
pub sync: bool,
@@ -100,7 +96,7 @@ pub struct SafekeeperConf {
impl Default for SafekeeperConf {
fn default() -> Self {
Self {
id: ZNodeId(0),
name: String::new(),
pg_port: 0,
http_port: 0,
sync: true,
@@ -140,8 +136,8 @@ impl LocalEnv {
self.base_data_dir.clone()
}
pub fn safekeeper_data_dir(&self, data_dir_name: &str) -> PathBuf {
self.base_data_dir.join("safekeepers").join(data_dir_name)
pub fn safekeeper_data_dir(&self, node_name: &str) -> PathBuf {
self.base_data_dir.join("safekeepers").join(node_name)
}
/// Create a LocalEnv from a config file.
@@ -184,7 +180,7 @@ impl LocalEnv {
// If no initial tenant ID was given, generate it.
if env.default_tenantid.is_none() {
env.default_tenantid = Some(HexZTenantId::from(ZTenantId::generate()));
env.default_tenantid = Some(ZTenantId::generate());
}
env.base_data_dir = base_path();
@@ -289,7 +285,7 @@ impl LocalEnv {
fs::create_dir_all(self.pg_data_dirs_path())?;
for safekeeper in &self.safekeepers {
fs::create_dir_all(SafekeeperNode::datadir_path_by_id(self, safekeeper.id))?;
fs::create_dir_all(self.safekeeper_data_dir(&safekeeper.name))?;
}
let mut conf_content = String::new();

View File

@@ -15,7 +15,6 @@ use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::{IntoUrl, Method};
use thiserror::Error;
use zenith_utils::http::error::HttpErrorBody;
use zenith_utils::zid::ZNodeId;
use crate::local_env::{LocalEnv, SafekeeperConf};
use crate::storage::PageServerNode;
@@ -62,7 +61,7 @@ impl ResponseErrorMessageExt for Response {
//
#[derive(Debug)]
pub struct SafekeeperNode {
pub id: ZNodeId,
pub name: String,
pub conf: SafekeeperConf,
@@ -78,10 +77,10 @@ impl SafekeeperNode {
pub fn from_env(env: &LocalEnv, conf: &SafekeeperConf) -> SafekeeperNode {
let pageserver = Arc::new(PageServerNode::from_env(env));
println!("initializing for sk {} for {}", conf.id, conf.http_port);
println!("initializing for {} for {}", conf.name, conf.http_port);
SafekeeperNode {
id: conf.id,
name: conf.name.clone(),
conf: conf.clone(),
pg_connection_config: Self::safekeeper_connection_config(conf.pg_port),
env: env.clone(),
@@ -99,12 +98,8 @@ impl SafekeeperNode {
.unwrap()
}
pub fn datadir_path_by_id(env: &LocalEnv, sk_id: ZNodeId) -> PathBuf {
env.safekeeper_data_dir(format!("sk{}", sk_id).as_ref())
}
pub fn datadir_path(&self) -> PathBuf {
SafekeeperNode::datadir_path_by_id(&self.env, self.id)
self.env.safekeeper_data_dir(&self.name)
}
pub fn pid_file(&self) -> PathBuf {
@@ -125,7 +120,6 @@ impl SafekeeperNode {
let mut cmd = Command::new(self.env.safekeeper_bin()?);
fill_rust_env_vars(
cmd.args(&["-D", self.datadir_path().to_str().unwrap()])
.args(&["--id", self.id.to_string().as_ref()])
.args(&["--listen-pg", &listen_pg])
.args(&["--listen-http", &listen_http])
.args(&["--recall", "1 second"])
@@ -189,7 +183,7 @@ impl SafekeeperNode {
pub fn stop(&self, immediate: bool) -> anyhow::Result<()> {
let pid_file = self.pid_file();
if !pid_file.exists() {
println!("Safekeeper {} is already stopped", self.id);
println!("Safekeeper {} is already stopped", self.name);
return Ok(());
}
let pid = read_pidfile(&pid_file)?;

View File

@@ -103,8 +103,6 @@ impl PageServerNode {
) -> anyhow::Result<()> {
let mut cmd = Command::new(self.env.pageserver_bin()?);
let id = format!("id={}", self.env.pageserver.id);
// FIXME: the paths should be shell-escaped to handle paths with spaces, quotas etc.
let base_data_dir_param = self.env.base_data_dir.display().to_string();
let pg_distrib_dir_param =
@@ -124,7 +122,6 @@ impl PageServerNode {
args.extend(["-c", &authg_type_param]);
args.extend(["-c", &listen_http_addr_param]);
args.extend(["-c", &listen_pg_addr_param]);
args.extend(["-c", &id]);
for config_override in config_overrides {
args.extend(["-c", config_override]);

View File

@@ -4,7 +4,7 @@ set -eux
if [ "$1" = 'pageserver' ]; then
if [ ! -d "/data/tenants" ]; then
echo "Initializing pageserver data directory"
pageserver --init -D /data -c "pg_distrib_dir='/usr/local'" -c "id=10"
pageserver --init -D /data -c "pg_distrib_dir='/usr/local'"
fi
echo "Staring pageserver at 0.0.0.0:6400"
pageserver -c "listen_pg_addr='0.0.0.0:6400'" -c "listen_http_addr='0.0.0.0:9898'" -D /data

View File

@@ -7,14 +7,32 @@ Currently we build two main images:
- [zenithdb/zenith](https://hub.docker.com/repository/docker/zenithdb/zenith) — image with pre-built `pageserver`, `safekeeper` and `proxy` binaries and all the required runtime dependencies. Built from [/Dockerfile](/Dockerfile).
- [zenithdb/compute-node](https://hub.docker.com/repository/docker/zenithdb/compute-node) — compute node image with pre-built Postgres binaries from [zenithdb/postgres](https://github.com/zenithdb/postgres).
And additional intermediate images:
And two intermediate images used either to reduce build time or to deliver some additional binary tools from other repos:
- [zenithdb/build](https://hub.docker.com/repository/docker/zenithdb/build) — image with all the dependencies required to build Zenith and compute node images. This image is based on `rust:slim-buster`, so it also has a proper `rust` environment. Built from [/Dockerfile.build](/Dockerfile.build).
- [zenithdb/compute-tools](https://hub.docker.com/repository/docker/zenithdb/compute-tools) — compute node configuration management tools.
## Building pipeline
1. Image `zenithdb/compute-tools` is re-built automatically.
2. Image `zenithdb/compute-node` is built independently in the [zenithdb/postgres](https://github.com/zenithdb/postgres) repo.
2. Image `zenithdb/build` is built manually. If you want to introduce any new compile time dependencies to Zenith or compute node you have to update this image as well, build it and push to Docker Hub.
3. Image `zenithdb/zenith` is built in this repo after a successful `release` tests run and pushed to Docker Hub automatically.
Build:
```sh
docker build -t zenithdb/build:buster -f Dockerfile.build .
```
Login:
```sh
docker login
```
Push to Docker Hub:
```sh
docker push zenithdb/build:buster
```
3. Image `zenithdb/compute-node` is built independently in the [zenithdb/postgres](https://github.com/zenithdb/postgres) repo.
4. Image `zenithdb/zenith` is built in this repo after a successful `release` tests run and pushed to Docker Hub automatically.

View File

@@ -41,7 +41,6 @@ url = "2"
nix = "0.23"
once_cell = "1.8.0"
crossbeam-utils = "0.8.5"
fail = "0.5.0"
rust-s3 = { version = "0.28", default-features = false, features = ["no-verify-ssl", "tokio-rustls-tls"] }
async-compression = {version = "0.3", features = ["zstd", "tokio"]}

View File

@@ -61,7 +61,7 @@ fn main() -> Result<()> {
.number_of_values(1)
.multiple_occurrences(true)
.help("Additional configuration overrides of the ones from the toml config file (or new ones to add there).
Any option has to be a valid toml document, example: `-c=\"foo='hey'\"` `-c=\"foo={value=1}\"`"),
Any option has to be a valid toml document, example: `-c \"foo='hey'\"` `-c \"foo={value=1}\"`"),
)
.get_matches();
@@ -115,14 +115,7 @@ fn main() -> Result<()> {
option_line
)
})?;
for (key, item) in doc.iter() {
if key == "id" {
anyhow::ensure!(
init,
"node id can only be set during pageserver init and cannot be overridden"
);
}
toml.insert(key, item.clone());
}
}

View File

@@ -16,9 +16,10 @@ use std::{
};
use tracing::*;
use zenith_utils::crashsafe_dir;
use zenith_utils::logging;
use zenith_utils::lsn::Lsn;
use zenith_utils::zid::{ZTenantId, ZTimelineId};
use zenith_utils::{crashsafe_dir, logging};
use crate::walredo::WalRedoManager;
use crate::CheckpointConfig;

View File

@@ -8,7 +8,7 @@ use anyhow::{bail, ensure, Context, Result};
use toml_edit;
use toml_edit::{Document, Item};
use zenith_utils::postgres_backend::AuthType;
use zenith_utils::zid::{ZNodeId, ZTenantId, ZTimelineId};
use zenith_utils::zid::{ZTenantId, ZTimelineId};
use std::convert::TryInto;
use std::env;
@@ -36,9 +36,6 @@ pub mod defaults {
pub const DEFAULT_GC_HORIZON: u64 = 64 * 1024 * 1024;
pub const DEFAULT_GC_PERIOD: &str = "100 s";
pub const DEFAULT_WAIT_LSN_TIMEOUT: &str = "60 s";
pub const DEFAULT_WAL_REDO_TIMEOUT: &str = "60 s";
pub const DEFAULT_SUPERUSER: &str = "zenith_admin";
pub const DEFAULT_REMOTE_STORAGE_MAX_CONCURRENT_SYNC: usize = 100;
pub const DEFAULT_REMOTE_STORAGE_MAX_SYNC_ERRORS: u32 = 10;
@@ -62,9 +59,6 @@ pub mod defaults {
#gc_period = '{DEFAULT_GC_PERIOD}'
#gc_horizon = {DEFAULT_GC_HORIZON}
#wait_lsn_timeout = '{DEFAULT_WAIT_LSN_TIMEOUT}'
#wal_redo_timeout = '{DEFAULT_WAL_REDO_TIMEOUT}'
#max_file_descriptors = {DEFAULT_MAX_FILE_DESCRIPTORS}
# initial superuser role name to use when creating a new tenant
@@ -78,10 +72,6 @@ pub mod defaults {
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PageServerConf {
// Identifier of that particular pageserver so e g safekeepers
// can safely distinguish different pageservers
pub id: ZNodeId,
/// Example (default): 127.0.0.1:64000
pub listen_pg_addr: String,
/// Example (default): 127.0.0.1:9898
@@ -95,12 +85,6 @@ pub struct PageServerConf {
pub gc_horizon: u64,
pub gc_period: Duration,
// Timeout when waiting for WAL receiver to catch up to an LSN given in a GetPage@LSN call.
pub wait_lsn_timeout: Duration,
// How long to wait for WAL redo to complete.
pub wal_redo_timeout: Duration,
pub superuser: String,
pub page_cache_size: usize,
@@ -122,206 +106,6 @@ pub struct PageServerConf {
pub remote_storage_config: Option<RemoteStorageConfig>,
}
// use dedicated enum for builder to better indicate the intention
// and avoid possible confusion with nested options
pub enum BuilderValue<T> {
Set(T),
NotSet,
}
impl<T> BuilderValue<T> {
pub fn ok_or<E>(self, err: E) -> Result<T, E> {
match self {
Self::Set(v) => Ok(v),
Self::NotSet => Err(err),
}
}
}
// needed to simplify config construction
struct PageServerConfigBuilder {
listen_pg_addr: BuilderValue<String>,
listen_http_addr: BuilderValue<String>,
checkpoint_distance: BuilderValue<u64>,
checkpoint_period: BuilderValue<Duration>,
gc_horizon: BuilderValue<u64>,
gc_period: BuilderValue<Duration>,
wait_lsn_timeout: BuilderValue<Duration>,
wal_redo_timeout: BuilderValue<Duration>,
superuser: BuilderValue<String>,
page_cache_size: BuilderValue<usize>,
max_file_descriptors: BuilderValue<usize>,
workdir: BuilderValue<PathBuf>,
pg_distrib_dir: BuilderValue<PathBuf>,
auth_type: BuilderValue<AuthType>,
//
auth_validation_public_key_path: BuilderValue<Option<PathBuf>>,
remote_storage_config: BuilderValue<Option<RemoteStorageConfig>>,
id: BuilderValue<ZNodeId>,
}
impl Default for PageServerConfigBuilder {
fn default() -> Self {
use self::BuilderValue::*;
use defaults::*;
Self {
listen_pg_addr: Set(DEFAULT_PG_LISTEN_ADDR.to_string()),
listen_http_addr: Set(DEFAULT_HTTP_LISTEN_ADDR.to_string()),
checkpoint_distance: Set(DEFAULT_CHECKPOINT_DISTANCE),
checkpoint_period: Set(humantime::parse_duration(DEFAULT_CHECKPOINT_PERIOD)
.expect("cannot parse default checkpoint period")),
gc_horizon: Set(DEFAULT_GC_HORIZON),
gc_period: Set(humantime::parse_duration(DEFAULT_GC_PERIOD)
.expect("cannot parse default gc period")),
wait_lsn_timeout: Set(humantime::parse_duration(DEFAULT_WAIT_LSN_TIMEOUT)
.expect("cannot parse default wait lsn timeout")),
wal_redo_timeout: Set(humantime::parse_duration(DEFAULT_WAL_REDO_TIMEOUT)
.expect("cannot parse default wal redo timeout")),
superuser: Set(DEFAULT_SUPERUSER.to_string()),
page_cache_size: Set(DEFAULT_PAGE_CACHE_SIZE),
max_file_descriptors: Set(DEFAULT_MAX_FILE_DESCRIPTORS),
workdir: Set(PathBuf::new()),
pg_distrib_dir: Set(env::current_dir()
.expect("cannot access current directory")
.join("tmp_install")),
auth_type: Set(AuthType::Trust),
auth_validation_public_key_path: Set(None),
remote_storage_config: Set(None),
id: NotSet,
}
}
}
impl PageServerConfigBuilder {
pub fn listen_pg_addr(&mut self, listen_pg_addr: String) {
self.listen_pg_addr = BuilderValue::Set(listen_pg_addr)
}
pub fn listen_http_addr(&mut self, listen_http_addr: String) {
self.listen_http_addr = BuilderValue::Set(listen_http_addr)
}
pub fn checkpoint_distance(&mut self, checkpoint_distance: u64) {
self.checkpoint_distance = BuilderValue::Set(checkpoint_distance)
}
pub fn checkpoint_period(&mut self, checkpoint_period: Duration) {
self.checkpoint_period = BuilderValue::Set(checkpoint_period)
}
pub fn gc_horizon(&mut self, gc_horizon: u64) {
self.gc_horizon = BuilderValue::Set(gc_horizon)
}
pub fn gc_period(&mut self, gc_period: Duration) {
self.gc_period = BuilderValue::Set(gc_period)
}
pub fn wait_lsn_timeout(&mut self, wait_lsn_timeout: Duration) {
self.wait_lsn_timeout = BuilderValue::Set(wait_lsn_timeout)
}
pub fn wal_redo_timeout(&mut self, wal_redo_timeout: Duration) {
self.wal_redo_timeout = BuilderValue::Set(wal_redo_timeout)
}
pub fn superuser(&mut self, superuser: String) {
self.superuser = BuilderValue::Set(superuser)
}
pub fn page_cache_size(&mut self, page_cache_size: usize) {
self.page_cache_size = BuilderValue::Set(page_cache_size)
}
pub fn max_file_descriptors(&mut self, max_file_descriptors: usize) {
self.max_file_descriptors = BuilderValue::Set(max_file_descriptors)
}
pub fn workdir(&mut self, workdir: PathBuf) {
self.workdir = BuilderValue::Set(workdir)
}
pub fn pg_distrib_dir(&mut self, pg_distrib_dir: PathBuf) {
self.pg_distrib_dir = BuilderValue::Set(pg_distrib_dir)
}
pub fn auth_type(&mut self, auth_type: AuthType) {
self.auth_type = BuilderValue::Set(auth_type)
}
pub fn auth_validation_public_key_path(
&mut self,
auth_validation_public_key_path: Option<PathBuf>,
) {
self.auth_validation_public_key_path = BuilderValue::Set(auth_validation_public_key_path)
}
pub fn remote_storage_config(&mut self, remote_storage_config: Option<RemoteStorageConfig>) {
self.remote_storage_config = BuilderValue::Set(remote_storage_config)
}
pub fn id(&mut self, node_id: ZNodeId) {
self.id = BuilderValue::Set(node_id)
}
pub fn build(self) -> Result<PageServerConf> {
Ok(PageServerConf {
listen_pg_addr: self
.listen_pg_addr
.ok_or(anyhow::anyhow!("missing listen_pg_addr"))?,
listen_http_addr: self
.listen_http_addr
.ok_or(anyhow::anyhow!("missing listen_http_addr"))?,
checkpoint_distance: self
.checkpoint_distance
.ok_or(anyhow::anyhow!("missing checkpoint_distance"))?,
checkpoint_period: self
.checkpoint_period
.ok_or(anyhow::anyhow!("missing checkpoint_period"))?,
gc_horizon: self
.gc_horizon
.ok_or(anyhow::anyhow!("missing gc_horizon"))?,
gc_period: self.gc_period.ok_or(anyhow::anyhow!("missing gc_period"))?,
wait_lsn_timeout: self
.wait_lsn_timeout
.ok_or(anyhow::anyhow!("missing wait_lsn_timeout"))?,
wal_redo_timeout: self
.wal_redo_timeout
.ok_or(anyhow::anyhow!("missing wal_redo_timeout"))?,
superuser: self.superuser.ok_or(anyhow::anyhow!("missing superuser"))?,
page_cache_size: self
.page_cache_size
.ok_or(anyhow::anyhow!("missing page_cache_size"))?,
max_file_descriptors: self
.max_file_descriptors
.ok_or(anyhow::anyhow!("missing max_file_descriptors"))?,
workdir: self.workdir.ok_or(anyhow::anyhow!("missing workdir"))?,
pg_distrib_dir: self
.pg_distrib_dir
.ok_or(anyhow::anyhow!("missing pg_distrib_dir"))?,
auth_type: self.auth_type.ok_or(anyhow::anyhow!("missing auth_type"))?,
auth_validation_public_key_path: self
.auth_validation_public_key_path
.ok_or(anyhow::anyhow!("missing auth_validation_public_key_path"))?,
remote_storage_config: self
.remote_storage_config
.ok_or(anyhow::anyhow!("missing remote_storage_config"))?,
id: self.id.ok_or(anyhow::anyhow!("missing id"))?,
})
}
}
/// External backup storage configuration, enough for creating a client for that storage.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RemoteStorageConfig {
@@ -437,41 +221,57 @@ impl PageServerConf {
///
/// This leaves any options not present in the file in the built-in defaults.
pub fn parse_and_validate(toml: &Document, workdir: &Path) -> Result<Self> {
let mut builder = PageServerConfigBuilder::default();
builder.workdir(workdir.to_owned());
use defaults::*;
let mut conf = PageServerConf {
workdir: workdir.to_path_buf(),
listen_pg_addr: DEFAULT_PG_LISTEN_ADDR.to_string(),
listen_http_addr: DEFAULT_HTTP_LISTEN_ADDR.to_string(),
checkpoint_distance: DEFAULT_CHECKPOINT_DISTANCE,
checkpoint_period: humantime::parse_duration(DEFAULT_CHECKPOINT_PERIOD)?,
gc_horizon: DEFAULT_GC_HORIZON,
gc_period: humantime::parse_duration(DEFAULT_GC_PERIOD)?,
page_cache_size: DEFAULT_PAGE_CACHE_SIZE,
max_file_descriptors: DEFAULT_MAX_FILE_DESCRIPTORS,
pg_distrib_dir: PathBuf::new(),
auth_validation_public_key_path: None,
auth_type: AuthType::Trust,
remote_storage_config: None,
superuser: DEFAULT_SUPERUSER.to_string(),
};
for (key, item) in toml.iter() {
match key {
"listen_pg_addr" => builder.listen_pg_addr(parse_toml_string(key, item)?),
"listen_http_addr" => builder.listen_http_addr(parse_toml_string(key, item)?),
"checkpoint_distance" => builder.checkpoint_distance(parse_toml_u64(key, item)?),
"checkpoint_period" => builder.checkpoint_period(parse_toml_duration(key, item)?),
"gc_horizon" => builder.gc_horizon(parse_toml_u64(key, item)?),
"gc_period" => builder.gc_period(parse_toml_duration(key, item)?),
"wait_lsn_timeout" => builder.wait_lsn_timeout(parse_toml_duration(key, item)?),
"wal_redo_timeout" => builder.wal_redo_timeout(parse_toml_duration(key, item)?),
"initial_superuser_name" => builder.superuser(parse_toml_string(key, item)?),
"page_cache_size" => builder.page_cache_size(parse_toml_u64(key, item)? as usize),
"listen_pg_addr" => conf.listen_pg_addr = parse_toml_string(key, item)?,
"listen_http_addr" => conf.listen_http_addr = parse_toml_string(key, item)?,
"checkpoint_distance" => conf.checkpoint_distance = parse_toml_u64(key, item)?,
"checkpoint_period" => conf.checkpoint_period = parse_toml_duration(key, item)?,
"gc_horizon" => conf.gc_horizon = parse_toml_u64(key, item)?,
"gc_period" => conf.gc_period = parse_toml_duration(key, item)?,
"initial_superuser_name" => conf.superuser = parse_toml_string(key, item)?,
"page_cache_size" => conf.page_cache_size = parse_toml_u64(key, item)? as usize,
"max_file_descriptors" => {
builder.max_file_descriptors(parse_toml_u64(key, item)? as usize)
conf.max_file_descriptors = parse_toml_u64(key, item)? as usize
}
"pg_distrib_dir" => {
builder.pg_distrib_dir(PathBuf::from(parse_toml_string(key, item)?))
conf.pg_distrib_dir = PathBuf::from(parse_toml_string(key, item)?)
}
"auth_validation_public_key_path" => builder.auth_validation_public_key_path(Some(
PathBuf::from(parse_toml_string(key, item)?),
)),
"auth_type" => builder.auth_type(parse_toml_auth_type(key, item)?),
"auth_validation_public_key_path" => {
conf.auth_validation_public_key_path =
Some(PathBuf::from(parse_toml_string(key, item)?))
}
"auth_type" => conf.auth_type = parse_toml_auth_type(key, item)?,
"remote_storage" => {
builder.remote_storage_config(Some(Self::parse_remote_storage_config(item)?))
conf.remote_storage_config = Some(Self::parse_remote_storage_config(item)?)
}
"id" => builder.id(ZNodeId(parse_toml_u64(key, item)?)),
_ => bail!("unrecognized pageserver option '{}'", key),
}
}
let mut conf = builder.build().context("invalid config")?;
if conf.auth_type == AuthType::ZenithJWT {
let auth_validation_public_key_path = conf
.auth_validation_public_key_path
@@ -485,6 +285,9 @@ impl PageServerConf {
);
}
if conf.pg_distrib_dir == PathBuf::new() {
conf.pg_distrib_dir = env::current_dir()?.join("tmp_install")
};
if !conf.pg_distrib_dir.join("bin/postgres").exists() {
bail!(
"Can't find postgres binary at {}",
@@ -579,13 +382,10 @@ impl PageServerConf {
#[cfg(test)]
pub fn dummy_conf(repo_dir: PathBuf) -> Self {
PageServerConf {
id: ZNodeId(0),
checkpoint_distance: defaults::DEFAULT_CHECKPOINT_DISTANCE,
checkpoint_period: Duration::from_secs(10),
gc_horizon: defaults::DEFAULT_GC_HORIZON,
gc_period: Duration::from_secs(10),
wait_lsn_timeout: Duration::from_secs(60),
wal_redo_timeout: Duration::from_secs(60),
page_cache_size: defaults::DEFAULT_PAGE_CACHE_SIZE,
max_file_descriptors: defaults::DEFAULT_MAX_FILE_DESCRIPTORS,
listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(),
@@ -656,24 +456,20 @@ checkpoint_period = '111 s'
gc_period = '222 s'
gc_horizon = 222
wait_lsn_timeout = '111 s'
wal_redo_timeout = '111 s'
page_cache_size = 444
max_file_descriptors = 333
# initial superuser role name to use when creating a new tenant
initial_superuser_name = 'zzzz'
id = 10
"#;
"#;
#[test]
fn parse_defaults() -> anyhow::Result<()> {
let tempdir = tempdir()?;
let (workdir, pg_distrib_dir) = prepare_fs(&tempdir)?;
// we have to create dummy pathes to overcome the validation errors
let config_string = format!("pg_distrib_dir='{}'\nid=10", pg_distrib_dir.display());
let config_string = format!("pg_distrib_dir='{}'", pg_distrib_dir.display());
let toml = config_string.parse()?;
let parsed_config =
@@ -684,15 +480,12 @@ id = 10
assert_eq!(
parsed_config,
PageServerConf {
id: ZNodeId(10),
listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(),
listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(),
checkpoint_distance: defaults::DEFAULT_CHECKPOINT_DISTANCE,
checkpoint_period: humantime::parse_duration(defaults::DEFAULT_CHECKPOINT_PERIOD)?,
gc_horizon: defaults::DEFAULT_GC_HORIZON,
gc_period: humantime::parse_duration(defaults::DEFAULT_GC_PERIOD)?,
wait_lsn_timeout: humantime::parse_duration(defaults::DEFAULT_WAIT_LSN_TIMEOUT)?,
wal_redo_timeout: humantime::parse_duration(defaults::DEFAULT_WAL_REDO_TIMEOUT)?,
superuser: defaults::DEFAULT_SUPERUSER.to_string(),
page_cache_size: defaults::DEFAULT_PAGE_CACHE_SIZE,
max_file_descriptors: defaults::DEFAULT_MAX_FILE_DESCRIPTORS,
@@ -728,15 +521,12 @@ id = 10
assert_eq!(
parsed_config,
PageServerConf {
id: ZNodeId(10),
listen_pg_addr: "127.0.0.1:64000".to_string(),
listen_http_addr: "127.0.0.1:9898".to_string(),
checkpoint_distance: 111,
checkpoint_period: Duration::from_secs(111),
gc_horizon: 222,
gc_period: Duration::from_secs(222),
wait_lsn_timeout: Duration::from_secs(111),
wal_redo_timeout: Duration::from_secs(111),
superuser: "zzzz".to_string(),
page_cache_size: 444,
max_file_descriptors: 333,

View File

@@ -1,7 +1,6 @@
use serde::{Deserialize, Serialize};
use crate::ZTenantId;
use zenith_utils::zid::ZNodeId;
#[derive(Serialize, Deserialize)]
pub struct BranchCreateRequest {
@@ -16,8 +15,3 @@ pub struct TenantCreateRequest {
#[serde(with = "hex")]
pub tenant_id: ZTenantId,
}
#[derive(Serialize)]
pub struct StatusResponse {
pub id: ZNodeId,
}

View File

@@ -17,11 +17,6 @@ paths:
application/json:
schema:
type: object
required:
- id
properties:
id:
type: integer
/v1/timeline/{tenant_id}:
parameters:
- name: tenant_id
@@ -239,7 +234,9 @@ paths:
content:
application/json:
schema:
$ref: "#/components/schemas/BranchInfo"
type: array
items:
$ref: "#/components/schemas/BranchInfo"
"400":
description: Malformed branch create request
content:
@@ -373,15 +370,12 @@ components:
format: hex
ancestor_id:
type: string
format: hex
ancestor_lsn:
type: string
current_logical_size:
type: integer
current_logical_size_non_incremental:
type: integer
latest_valid_lsn:
type: integer
TimelineInfo:
type: object
required:

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use anyhow::{Context, Result};
use hyper::header;
use hyper::StatusCode;
use hyper::{Body, Request, Response, Uri};
use serde::Serialize;
@@ -19,11 +20,9 @@ use zenith_utils::http::{
};
use zenith_utils::http::{RequestExt, RouterBuilder};
use zenith_utils::lsn::Lsn;
use zenith_utils::zid::HexZTimelineId;
use zenith_utils::zid::ZTimelineId;
use zenith_utils::zid::{opt_display_serde, ZTimelineId};
use super::models::BranchCreateRequest;
use super::models::StatusResponse;
use super::models::TenantCreateRequest;
use crate::branches::BranchInfo;
use crate::repository::RepositoryTimeline;
@@ -65,12 +64,12 @@ fn get_config(request: &Request<Body>) -> &'static PageServerConf {
}
// healthcheck handler
async fn status_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
let config = get_config(&request);
Ok(json_response(
StatusCode::OK,
StatusResponse { id: config.id },
)?)
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
Ok(Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from("{}"))
.map_err(ApiError::from_err)?)
}
async fn branch_create_handler(mut request: Request<Body>) -> Result<Response<Body>, ApiError> {
@@ -199,7 +198,8 @@ enum TimelineInfo {
timeline_id: ZTimelineId,
#[serde(with = "hex")]
tenant_id: ZTenantId,
ancestor_timeline_id: Option<HexZTimelineId>,
#[serde(with = "opt_display_serde")]
ancestor_timeline_id: Option<ZTimelineId>,
last_record_lsn: Lsn,
prev_record_lsn: Lsn,
disk_consistent_lsn: Lsn,
@@ -232,9 +232,7 @@ async fn timeline_detail_handler(request: Request<Body>) -> Result<Response<Body
Some(timeline) => TimelineInfo::Local {
timeline_id,
tenant_id,
ancestor_timeline_id: timeline
.get_ancestor_timeline_id()
.map(HexZTimelineId::from),
ancestor_timeline_id: timeline.get_ancestor_timeline_id(),
disk_consistent_lsn: timeline.get_disk_consistent_lsn(),
last_record_lsn: timeline.get_last_record_lsn(),
prev_record_lsn: timeline.get_prev_record_lsn(),

View File

@@ -29,7 +29,7 @@ use std::ops::{Bound::Included, Deref};
use std::path::{Path, PathBuf};
use std::sync::atomic::{self, AtomicBool, AtomicUsize};
use std::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
use std::time::Instant;
use std::time::{Duration, Instant};
use self::metadata::{metadata_path, TimelineMetadata, METADATA_FILE_NAME};
use crate::config::PageServerConf;
@@ -64,6 +64,7 @@ mod inmemory_layer;
mod interval_tree;
mod layer_map;
pub mod metadata;
mod page_versions;
mod par_fsync;
mod storage_layer;
@@ -82,6 +83,9 @@ pub use crate::layered_repository::ephemeral_file::writeback as writeback_epheme
static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; 8192]);
// Timeout when waiting for WAL receiver to catch up to an LSN given in a GetPage@LSN call.
static TIMEOUT: Duration = Duration::from_secs(60);
// Metrics collected on operations on the storage repository.
lazy_static! {
static ref STORAGE_TIME: HistogramVec = register_histogram_vec!(
@@ -812,7 +816,7 @@ impl Timeline for LayeredTimeline {
);
self.last_record_lsn
.wait_for_timeout(lsn, self.conf.wait_lsn_timeout)
.wait_for_timeout(lsn, TIMEOUT)
.with_context(|| {
format!(
"Timed out while waiting for WAL record at LSN {} to arrive, last_record_lsn {} disk consistent LSN={}",
@@ -893,11 +897,12 @@ impl Timeline for LayeredTimeline {
let seg = SegmentTag { rel, segno: 0 };
let result = if let Some((layer, lsn)) = self.get_layer_for_read(seg, lsn)? {
layer.get_seg_exists(lsn)?
let result;
if let Some((layer, lsn)) = self.get_layer_for_read(seg, lsn)? {
result = layer.get_seg_exists(lsn)?;
} else {
false
};
result = false;
}
trace!("get_rel_exists: {} at {} -> {}", rel, lsn, result);
Ok(result)
@@ -1939,21 +1944,22 @@ impl LayeredTimeline {
// for redo.
let rel = seg.rel;
let rel_blknum = seg.segno * RELISH_SEG_SIZE + seg_blknum;
let cached_page_img = match self.lookup_cached_page(&rel, rel_blknum, lsn) {
let (cached_lsn_opt, cached_page_opt) = match self.lookup_cached_page(&rel, rel_blknum, lsn)
{
Some((cached_lsn, cached_img)) => {
match cached_lsn.cmp(&lsn) {
cmp::Ordering::Less => {} // there might be WAL between cached_lsn and lsn, we need to check
cmp::Ordering::Equal => return Ok(cached_img), // exact LSN match, return the image
cmp::Ordering::Greater => panic!(), // the returned lsn should never be after the requested lsn
}
Some((cached_lsn, cached_img))
(Some(cached_lsn), Some((cached_lsn, cached_img)))
}
None => None,
None => (None, None),
};
let mut data = PageReconstructData {
records: Vec::new(),
page_img: cached_page_img,
page_img: None,
};
// Holds an Arc reference to 'layer_ref' when iterating in the loop below.
@@ -1966,14 +1972,15 @@ impl LayeredTimeline {
let mut curr_lsn = lsn;
loop {
let result = layer_ref
.get_page_reconstruct_data(seg_blknum, curr_lsn, &mut data)
.get_page_reconstruct_data(seg_blknum, curr_lsn, cached_lsn_opt, &mut data)
.with_context(|| {
format!(
"Failed to get reconstruct data {} {:?} {} {}",
"Failed to get reconstruct data {} {:?} {} {} {:?}",
layer_ref.get_seg_tag(),
layer_ref.filename(),
seg_blknum,
curr_lsn,
cached_lsn_opt,
)
})?;
match result {
@@ -2020,6 +2027,16 @@ impl LayeredTimeline {
lsn,
);
}
PageReconstructResult::Cached => {
let (cached_lsn, cached_img) = cached_page_opt.unwrap();
assert!(data.page_img.is_none());
if let Some((first_rec_lsn, first_rec)) = data.records.first() {
assert!(&cached_lsn < first_rec_lsn);
assert!(!first_rec.will_init());
}
data.page_img = Some(cached_img);
break;
}
}
}
@@ -2041,12 +2058,12 @@ impl LayeredTimeline {
// If we have a page image, and no WAL, we're all set
if data.records.is_empty() {
if let Some((img_lsn, img)) = &data.page_img {
if let Some(img) = &data.page_img {
trace!(
"found page image for blk {} in {} at {}, no WAL redo required",
rel_blknum,
rel,
img_lsn
request_lsn
);
Ok(img.clone())
} else {
@@ -2073,13 +2090,11 @@ impl LayeredTimeline {
);
Ok(ZERO_PAGE.clone())
} else {
let base_img = if let Some((_lsn, img)) = data.page_img {
if data.page_img.is_some() {
trace!("found {} WAL records and a base image for blk {} in {} at {}, performing WAL redo", data.records.len(), rel_blknum, rel, request_lsn);
Some(img)
} else {
trace!("found {} WAL records that will init the page for blk {} in {} at {}, performing WAL redo", data.records.len(), rel_blknum, rel, request_lsn);
None
};
}
let last_rec_lsn = data.records.last().unwrap().0;
@@ -2087,7 +2102,7 @@ impl LayeredTimeline {
rel,
rel_blknum,
request_lsn,
base_img,
data.page_img.clone(),
data.records,
)?;
@@ -2346,157 +2361,3 @@ fn rename_to_backup(path: PathBuf) -> anyhow::Result<()> {
bail!("couldn't find an unused backup number for {:?}", path)
}
///
/// Tests that are specific to the layered storage format.
///
/// There are more unit tests in repository.rs that work through the
/// Repository interface and are expected to work regardless of the
/// file format and directory layout. The test here are more low level.
///
#[cfg(test)]
mod tests {
use super::*;
use crate::repository::repo_harness::*;
#[test]
fn corrupt_metadata() -> Result<()> {
const TEST_NAME: &str = "corrupt_metadata";
let harness = RepoHarness::create(TEST_NAME)?;
let repo = harness.load();
repo.create_empty_timeline(TIMELINE_ID, Lsn(0))?;
drop(repo);
let metadata_path = harness.timeline_path(&TIMELINE_ID).join(METADATA_FILE_NAME);
assert!(metadata_path.is_file());
let mut metadata_bytes = std::fs::read(&metadata_path)?;
assert_eq!(metadata_bytes.len(), 512);
metadata_bytes[512 - 4 - 2] ^= 1;
std::fs::write(metadata_path, metadata_bytes)?;
let new_repo = harness.load();
let err = new_repo.get_timeline(TIMELINE_ID).err().unwrap();
assert_eq!(err.to_string(), "failed to load metadata");
assert_eq!(
err.source().unwrap().to_string(),
"metadata checksum mismatch"
);
Ok(())
}
///
/// Test the logic in 'load_layer_map' that removes layer files that are
/// newer than 'disk_consistent_lsn'.
///
#[test]
fn future_layerfiles() -> Result<()> {
const TEST_NAME: &str = "future_layerfiles";
let harness = RepoHarness::create(TEST_NAME)?;
let repo = harness.load();
// Create a timeline with disk_consistent_lsn = 8000
let tline = repo.create_empty_timeline(TIMELINE_ID, Lsn(0x8000))?;
let writer = tline.writer();
writer.advance_last_record_lsn(Lsn(0x8000));
drop(writer);
repo.checkpoint_iteration(CheckpointConfig::Forced)?;
drop(repo);
let timeline_path = harness.timeline_path(&TIMELINE_ID);
let make_empty_file = |filename: &str| -> std::io::Result<()> {
let path = timeline_path.join(filename);
assert!(!path.exists());
std::fs::write(&path, &[])?;
Ok(())
};
// Helper function to check that a relation file exists, and a corresponding
// <filename>.0.old file does not.
let assert_exists = |filename: &str| {
let path = timeline_path.join(filename);
assert!(path.exists(), "file {} was removed", filename);
// Check that there is no .old file
let backup_path = timeline_path.join(format!("{}.0.old", filename));
assert!(
!backup_path.exists(),
"unexpected backup file {}",
backup_path.display()
);
};
// Helper function to check that a relation file does *not* exists, and a corresponding
// <filename>.<num>.old file does.
let assert_is_renamed = |filename: &str, num: u32| {
let path = timeline_path.join(filename);
assert!(
!path.exists(),
"file {} was not removed as expected",
filename
);
let backup_path = timeline_path.join(format!("{}.{}.old", filename, num));
assert!(
backup_path.exists(),
"backup file {} was not created",
backup_path.display()
);
};
// These files are considered to be in the future and will be renamed out
// of the way
let future_filenames = vec![
format!("pg_control_0_{:016X}", 0x8001),
format!("pg_control_0_{:016X}_{:016X}", 0x8001, 0x8008),
];
// But these are not:
let past_filenames = vec![
format!("pg_control_0_{:016X}", 0x8000),
format!("pg_control_0_{:016X}_{:016X}", 0x7000, 0x8001),
];
for filename in future_filenames.iter().chain(past_filenames.iter()) {
make_empty_file(filename)?;
}
// Load the timeline. This will cause the files in the "future" to be renamed
// away.
let new_repo = harness.load();
new_repo.get_timeline(TIMELINE_ID).unwrap();
drop(new_repo);
for filename in future_filenames.iter() {
assert_is_renamed(filename, 0);
}
for filename in past_filenames.iter() {
assert_exists(filename);
}
// Create the future files again, and load again. They should be renamed to
// *.1.old this time.
for filename in future_filenames.iter() {
make_empty_file(filename)?;
}
let new_repo = harness.load();
new_repo.get_timeline(TIMELINE_ID).unwrap();
drop(new_repo);
for filename in future_filenames.iter() {
assert_is_renamed(filename, 0);
assert_is_renamed(filename, 1);
}
for filename in past_filenames.iter() {
assert_exists(filename);
}
Ok(())
}
}

View File

@@ -208,15 +208,16 @@ impl Layer for DeltaLayer {
&self,
blknum: SegmentBlk,
lsn: Lsn,
cached_img_lsn: Option<Lsn>,
reconstruct_data: &mut PageReconstructData,
) -> Result<PageReconstructResult> {
let mut need_image = true;
assert!((0..RELISH_SEG_SIZE).contains(&blknum));
match &reconstruct_data.page_img {
Some((cached_lsn, _)) if &self.end_lsn <= cached_lsn => {
return Ok(PageReconstructResult::Complete)
match &cached_img_lsn {
Some(cached_lsn) if &self.end_lsn <= cached_lsn => {
return Ok(PageReconstructResult::Cached)
}
_ => {}
}
@@ -239,9 +240,9 @@ impl Layer for DeltaLayer {
.iter()
.rev();
for ((_blknum, pv_lsn), blob_range) in iter {
match &reconstruct_data.page_img {
Some((cached_lsn, _)) if pv_lsn <= cached_lsn => {
return Ok(PageReconstructResult::Complete)
match &cached_img_lsn {
Some(cached_lsn) if pv_lsn <= cached_lsn => {
return Ok(PageReconstructResult::Cached)
}
_ => {}
}
@@ -251,7 +252,7 @@ impl Layer for DeltaLayer {
match pv {
PageVersion::Page(img) => {
// Found a page image, return it
reconstruct_data.page_img = Some((*pv_lsn, img));
reconstruct_data.page_img = Some(img);
need_image = false;
break;
}

View File

@@ -145,15 +145,14 @@ impl Layer for ImageLayer {
&self,
blknum: SegmentBlk,
lsn: Lsn,
cached_img_lsn: Option<Lsn>,
reconstruct_data: &mut PageReconstructData,
) -> Result<PageReconstructResult> {
assert!((0..RELISH_SEG_SIZE).contains(&blknum));
assert!(lsn >= self.lsn);
match reconstruct_data.page_img {
Some((cached_lsn, _)) if self.lsn <= cached_lsn => {
return Ok(PageReconstructResult::Complete)
}
match cached_img_lsn {
Some(cached_lsn) if self.lsn <= cached_lsn => return Ok(PageReconstructResult::Cached),
_ => {}
}
@@ -196,7 +195,7 @@ impl Layer for ImageLayer {
}
};
reconstruct_data.page_img = Some((self.lsn, Bytes::from(buf)));
reconstruct_data.page_img = Some(Bytes::from(buf));
Ok(PageReconstructResult::Complete)
}

View File

@@ -20,15 +20,13 @@ use crate::{ZTenantId, ZTimelineId};
use anyhow::{ensure, Result};
use bytes::Bytes;
use log::*;
use std::collections::HashMap;
use std::io::Seek;
use std::os::unix::fs::FileExt;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use zenith_utils::bin_ser::BeSer;
use zenith_utils::lsn::Lsn;
use zenith_utils::vec_map::VecMap;
use super::page_versions::PageVersions;
pub struct InMemoryLayer {
conf: &'static PageServerConf,
tenantid: ZTenantId,
@@ -73,15 +71,11 @@ pub struct InMemoryLayerInner {
/// The drop LSN is recorded in [`end_lsn`].
dropped: bool,
/// The PageVersion structs are stored in a serialized format in this file.
/// Each serialized PageVersion is preceded by a 'u32' length field.
/// 'page_versions' map stores offsets into this file.
file: EphemeralFile,
/// Metadata about all versions of all pages in the layer is kept
/// here. Indexed by block number and LSN. The value is an offset
/// into the ephemeral file where the page version is stored.
page_versions: HashMap<SegmentBlk, VecMap<Lsn, u64>>,
///
/// All versions of all pages in the layer are are kept here.
/// Indexed by block number and LSN.
///
page_versions: PageVersions,
///
/// `seg_sizes` tracks the size of the segment at different points in time.
@@ -117,50 +111,6 @@ impl InMemoryLayerInner {
panic!("could not find seg size in in-memory layer");
}
}
///
/// Read a page version from the ephemeral file.
///
fn read_pv(&self, off: u64) -> Result<PageVersion> {
let mut buf = Vec::new();
self.read_pv_bytes(off, &mut buf)?;
Ok(PageVersion::des(&buf)?)
}
///
/// Read a page version from the ephemeral file, as raw bytes, at
/// the given offset. The bytes are read into 'buf', which is
/// expanded if necessary. Returns the size of the page version.
///
fn read_pv_bytes(&self, off: u64, buf: &mut Vec<u8>) -> Result<usize> {
// read length
let mut lenbuf = [0u8; 4];
self.file.read_exact_at(&mut lenbuf, off)?;
let len = u32::from_ne_bytes(lenbuf) as usize;
if buf.len() < len {
buf.resize(len, 0);
}
self.file.read_exact_at(&mut buf[0..len], off + 4)?;
Ok(len)
}
fn write_pv(&mut self, pv: &PageVersion) -> Result<u64> {
// remember starting position
let pos = self.file.stream_position()?;
// make room for the 'length' field by writing zeros as a placeholder.
self.file.seek(std::io::SeekFrom::Start(pos + 4)).unwrap();
pv.ser_into(&mut self.file).unwrap();
// write the 'length' field.
let len = self.file.stream_position()? - pos - 4;
let lenbuf = u32::to_ne_bytes(len as u32);
self.file.write_all_at(&lenbuf, pos)?;
Ok(pos)
}
}
impl Layer for InMemoryLayer {
@@ -170,11 +120,12 @@ impl Layer for InMemoryLayer {
fn filename(&self) -> PathBuf {
let inner = self.inner.read().unwrap();
let end_lsn = if let Some(drop_lsn) = inner.end_lsn {
drop_lsn
let end_lsn;
if let Some(drop_lsn) = inner.end_lsn {
end_lsn = drop_lsn;
} else {
Lsn(u64::MAX)
};
end_lsn = Lsn(u64::MAX);
}
let delta_filename = DeltaFileName {
seg: self.seg,
@@ -223,6 +174,7 @@ impl Layer for InMemoryLayer {
&self,
blknum: SegmentBlk,
lsn: Lsn,
cached_img_lsn: Option<Lsn>,
reconstruct_data: &mut PageReconstructData,
) -> Result<PageReconstructResult> {
let mut need_image = true;
@@ -233,31 +185,33 @@ impl Layer for InMemoryLayer {
let inner = self.inner.read().unwrap();
// Scan the page versions backwards, starting from `lsn`.
if let Some(vec_map) = inner.page_versions.get(&blknum) {
let slice = vec_map.slice_range(..=lsn);
for (entry_lsn, pos) in slice.iter().rev() {
match &reconstruct_data.page_img {
Some((cached_lsn, _)) if entry_lsn <= cached_lsn => {
return Ok(PageReconstructResult::Complete)
}
_ => {}
let iter = inner
.page_versions
.get_block_lsn_range(blknum, ..=lsn)
.iter()
.rev();
for (entry_lsn, pos) in iter {
match &cached_img_lsn {
Some(cached_lsn) if entry_lsn <= cached_lsn => {
return Ok(PageReconstructResult::Cached)
}
_ => {}
}
let pv = inner.read_pv(*pos)?;
match pv {
PageVersion::Page(img) => {
reconstruct_data.page_img = Some((*entry_lsn, img));
let pv = inner.page_versions.read_pv(*pos)?;
match pv {
PageVersion::Page(img) => {
reconstruct_data.page_img = Some(img);
need_image = false;
break;
}
PageVersion::Wal(rec) => {
reconstruct_data.records.push((*entry_lsn, rec.clone()));
if rec.will_init() {
// This WAL record initializes the page, so no need to go further back
need_image = false;
break;
}
PageVersion::Wal(rec) => {
reconstruct_data.records.push((*entry_lsn, rec.clone()));
if rec.will_init() {
// This WAL record initializes the page, so no need to go further back
need_image = false;
break;
}
}
}
}
}
@@ -363,22 +317,14 @@ impl Layer for InMemoryLayer {
println!("seg_sizes {}: {}", k, v);
}
// List the blocks in order
let mut page_versions: Vec<(&SegmentBlk, &VecMap<Lsn, u64>)> =
inner.page_versions.iter().collect();
page_versions.sort_by_key(|k| k.0);
for (blknum, lsn, pos) in inner.page_versions.ordered_page_version_iter(None) {
let pv = inner.page_versions.read_pv(pos)?;
let pv_description = match pv {
PageVersion::Page(_img) => "page",
PageVersion::Wal(_rec) => "wal",
};
for (blknum, versions) in page_versions {
for (lsn, off) in versions.as_slice() {
let pv = inner.read_pv(*off);
let pv_description = match pv {
Ok(PageVersion::Page(_img)) => "page",
Ok(PageVersion::Wal(_rec)) => "wal",
Err(_err) => "INVALID",
};
println!("blk {} at {}: {}\n", blknum, lsn, pv_description);
}
println!("blk {} at {}: {}\n", blknum, lsn, pv_description);
}
Ok(())
@@ -439,8 +385,7 @@ impl InMemoryLayer {
inner: RwLock::new(InMemoryLayerInner {
end_lsn: None,
dropped: false,
file,
page_versions: HashMap::new(),
page_versions: PageVersions::new(file),
seg_sizes,
latest_lsn: oldest_lsn,
}),
@@ -482,18 +427,14 @@ impl InMemoryLayer {
assert!(lsn >= inner.latest_lsn);
inner.latest_lsn = lsn;
// Write the page version to the file, and remember its offset in 'page_versions'
{
let off = inner.write_pv(&pv)?;
let vec_map = inner.page_versions.entry(blknum).or_default();
let old = vec_map.append_or_update_last(lsn, off).unwrap().0;
if old.is_some() {
// We already had an entry for this LSN. That's odd..
warn!(
"Page version of rel {} blk {} at {} already exists",
self.seg.rel, blknum, lsn
);
}
let old = inner.page_versions.append_or_update_last(blknum, lsn, pv)?;
if old.is_some() {
// We already had an entry for this LSN. That's odd..
warn!(
"Page version of rel {} blk {} at {} already exists",
self.seg.rel, blknum, lsn
);
}
// Also update the relation size, if this extended the relation.
@@ -527,19 +468,16 @@ impl InMemoryLayer {
gapblknum,
blknum
);
let old = inner
.page_versions
.append_or_update_last(gapblknum, lsn, zeropv)?;
// We already had an entry for this LSN. That's odd..
// Write the page version to the file, and remember its offset in
// 'page_versions'
{
let off = inner.write_pv(&zeropv)?;
let vec_map = inner.page_versions.entry(gapblknum).or_default();
let old = vec_map.append_or_update_last(lsn, off).unwrap().0;
if old.is_some() {
warn!(
"Page version of seg {} blk {} at {} already exists",
self.seg, gapblknum, lsn
);
}
if old.is_some() {
warn!(
"Page version of seg {} blk {} at {} already exists",
self.seg, blknum, lsn
);
}
}
@@ -632,8 +570,7 @@ impl InMemoryLayer {
inner: RwLock::new(InMemoryLayerInner {
end_lsn: None,
dropped: false,
file,
page_versions: HashMap::new(),
page_versions: PageVersions::new(file),
seg_sizes,
latest_lsn: oldest_lsn,
}),
@@ -662,10 +599,8 @@ impl InMemoryLayer {
assert!(lsn <= &end_lsn, "{:?} {:?}", lsn, end_lsn);
}
for (_blk, vec_map) in inner.page_versions.iter() {
for (lsn, _pos) in vec_map.as_slice() {
assert!(*lsn <= end_lsn);
}
for (_blk, lsn, _pv) in inner.page_versions.ordered_page_version_iter(None) {
assert!(lsn <= end_lsn);
}
}
}
@@ -743,19 +678,15 @@ impl InMemoryLayer {
self.is_dropped(),
)?;
// Write all page versions, in block + LSN order
// Write all page versions
let mut buf: Vec<u8> = Vec::new();
let pv_iter = inner.page_versions.iter();
let mut pages: Vec<(&SegmentBlk, &VecMap<Lsn, u64>)> = pv_iter.collect();
pages.sort_by_key(|(blknum, _vec_map)| *blknum);
for (blknum, vec_map) in pages {
for (lsn, pos) in vec_map.as_slice() {
if *lsn < delta_end_lsn {
let len = inner.read_pv_bytes(*pos, &mut buf)?;
delta_layer_writer.put_page_version(*blknum, *lsn, &buf[..len])?;
}
}
let page_versions_iter = inner
.page_versions
.ordered_page_version_iter(Some(delta_end_lsn));
for (blknum, lsn, pos) in page_versions_iter {
let len = inner.page_versions.read_pv_bytes(pos, &mut buf)?;
delta_layer_writer.put_page_version(blknum, lsn, &buf[..len])?;
}
// Create seg_sizes

View File

@@ -0,0 +1,268 @@
//!
//! Data structure to ingest incoming WAL into an append-only file.
//!
//! - The file is considered temporary, and will be discarded on crash
//! - based on a B-tree
//!
use std::os::unix::fs::FileExt;
use std::{collections::HashMap, ops::RangeBounds, slice};
use anyhow::Result;
use std::cmp::min;
use std::io::Seek;
use zenith_utils::{lsn::Lsn, vec_map::VecMap};
use super::storage_layer::PageVersion;
use crate::layered_repository::ephemeral_file::EphemeralFile;
use zenith_utils::bin_ser::BeSer;
const EMPTY_SLICE: &[(Lsn, u64)] = &[];
pub struct PageVersions {
map: HashMap<u32, VecMap<Lsn, u64>>,
/// The PageVersion structs are stored in a serialized format in this file.
/// Each serialized PageVersion is preceded by a 'u32' length field.
/// The 'map' stores offsets into this file.
file: EphemeralFile,
}
impl PageVersions {
pub fn new(file: EphemeralFile) -> PageVersions {
PageVersions {
map: HashMap::new(),
file,
}
}
pub fn append_or_update_last(
&mut self,
blknum: u32,
lsn: Lsn,
page_version: PageVersion,
) -> Result<Option<u64>> {
// remember starting position
let pos = self.file.stream_position()?;
// make room for the 'length' field by writing zeros as a placeholder.
self.file.seek(std::io::SeekFrom::Start(pos + 4)).unwrap();
page_version.ser_into(&mut self.file).unwrap();
// write the 'length' field.
let len = self.file.stream_position()? - pos - 4;
let lenbuf = u32::to_ne_bytes(len as u32);
self.file.write_all_at(&lenbuf, pos)?;
let map = self.map.entry(blknum).or_insert_with(VecMap::default);
Ok(map.append_or_update_last(lsn, pos as u64).unwrap().0)
}
/// Get all [`PageVersion`]s in a block
fn get_block_slice(&self, blknum: u32) -> &[(Lsn, u64)] {
self.map
.get(&blknum)
.map(VecMap::as_slice)
.unwrap_or(EMPTY_SLICE)
}
/// Get a range of [`PageVersions`] in a block
pub fn get_block_lsn_range<R: RangeBounds<Lsn>>(&self, blknum: u32, range: R) -> &[(Lsn, u64)] {
self.map
.get(&blknum)
.map(|vec_map| vec_map.slice_range(range))
.unwrap_or(EMPTY_SLICE)
}
/// Iterate through [`PageVersion`]s in (block, lsn) order.
/// If a [`cutoff_lsn`] is set, only show versions with `lsn < cutoff_lsn`
pub fn ordered_page_version_iter(&self, cutoff_lsn: Option<Lsn>) -> OrderedPageVersionIter<'_> {
let mut ordered_blocks: Vec<u32> = self.map.keys().cloned().collect();
ordered_blocks.sort_unstable();
let slice = ordered_blocks
.first()
.map(|&blknum| self.get_block_slice(blknum))
.unwrap_or(EMPTY_SLICE);
OrderedPageVersionIter {
page_versions: self,
ordered_blocks,
cur_block_idx: 0,
cutoff_lsn,
cur_slice_iter: slice.iter(),
}
}
///
/// Read a page version.
///
pub fn read_pv(&self, off: u64) -> Result<PageVersion> {
let mut buf = Vec::new();
self.read_pv_bytes(off, &mut buf)?;
Ok(PageVersion::des(&buf)?)
}
///
/// Read a page version, as raw bytes, at the given offset. The bytes
/// are read into 'buf', which is expanded if necessary. Returns the
/// size of the page version.
///
pub fn read_pv_bytes(&self, off: u64, buf: &mut Vec<u8>) -> Result<usize> {
// read length
let mut lenbuf = [0u8; 4];
self.file.read_exact_at(&mut lenbuf, off)?;
let len = u32::from_ne_bytes(lenbuf) as usize;
// Resize the buffer to fit the data, if needed.
//
// We don't shrink the buffer if it's larger than necessary. That avoids
// repeatedly shrinking and expanding when you reuse the same buffer to
// read multiple page versions. Expanding a Vec requires initializing the
// new bytes, which is a waste of time because we're immediately overwriting
// it, but there's no way to avoid it without resorting to unsafe code.
if buf.len() < len {
buf.resize(len, 0);
}
self.file.read_exact_at(&mut buf[0..len], off + 4)?;
Ok(len)
}
}
pub struct PageVersionReader<'a> {
file: &'a EphemeralFile,
pos: u64,
end_pos: u64,
}
impl<'a> std::io::Read for PageVersionReader<'a> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
let len = min(buf.len(), (self.end_pos - self.pos) as usize);
let n = self.file.read_at(&mut buf[..len], self.pos)?;
self.pos += n as u64;
Ok(n)
}
}
pub struct OrderedPageVersionIter<'a> {
page_versions: &'a PageVersions,
ordered_blocks: Vec<u32>,
cur_block_idx: usize,
cutoff_lsn: Option<Lsn>,
cur_slice_iter: slice::Iter<'a, (Lsn, u64)>,
}
impl OrderedPageVersionIter<'_> {
fn is_lsn_before_cutoff(&self, lsn: &Lsn) -> bool {
if let Some(cutoff_lsn) = self.cutoff_lsn.as_ref() {
lsn < cutoff_lsn
} else {
true
}
}
}
impl<'a> Iterator for OrderedPageVersionIter<'a> {
type Item = (u32, Lsn, u64);
fn next(&mut self) -> Option<Self::Item> {
loop {
if let Some((lsn, pos)) = self.cur_slice_iter.next() {
if self.is_lsn_before_cutoff(lsn) {
let blknum = self.ordered_blocks[self.cur_block_idx];
return Some((blknum, *lsn, *pos));
}
}
let next_block_idx = self.cur_block_idx + 1;
let blknum: u32 = *self.ordered_blocks.get(next_block_idx)?;
self.cur_block_idx = next_block_idx;
self.cur_slice_iter = self.page_versions.get_block_slice(blknum).iter();
}
}
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use super::*;
use crate::config::PageServerConf;
use std::fs;
use std::str::FromStr;
use zenith_utils::zid::{ZTenantId, ZTimelineId};
fn repo_harness(test_name: &str) -> Result<(&'static PageServerConf, ZTenantId, ZTimelineId)> {
let repo_dir = PageServerConf::test_repo_dir(test_name);
let _ = fs::remove_dir_all(&repo_dir);
let conf = PageServerConf::dummy_conf(repo_dir);
// Make a static copy of the config. This can never be free'd, but that's
// OK in a test.
let conf: &'static PageServerConf = Box::leak(Box::new(conf));
let tenantid = ZTenantId::from_str("11000000000000000000000000000000").unwrap();
let timelineid = ZTimelineId::from_str("22000000000000000000000000000000").unwrap();
fs::create_dir_all(conf.timeline_path(&timelineid, &tenantid))?;
Ok((conf, tenantid, timelineid))
}
#[test]
fn test_ordered_iter() -> Result<()> {
let (conf, tenantid, timelineid) = repo_harness("test_ordered_iter")?;
let file = EphemeralFile::create(conf, tenantid, timelineid)?;
let mut page_versions = PageVersions::new(file);
const BLOCKS: u32 = 1000;
const LSNS: u64 = 50;
let empty_page = Bytes::from_static(&[0u8; 8192]);
let empty_page_version = PageVersion::Page(empty_page);
for blknum in 0..BLOCKS {
for lsn in 0..LSNS {
let old = page_versions.append_or_update_last(
blknum,
Lsn(lsn),
empty_page_version.clone(),
)?;
assert!(old.is_none());
}
}
let mut iter = page_versions.ordered_page_version_iter(None);
for blknum in 0..BLOCKS {
for lsn in 0..LSNS {
let (actual_blknum, actual_lsn, _pv) = iter.next().unwrap();
assert_eq!(actual_blknum, blknum);
assert_eq!(Lsn(lsn), actual_lsn);
}
}
assert!(iter.next().is_none());
assert!(iter.next().is_none()); // should be robust against excessive next() calls
const CUTOFF_LSN: Lsn = Lsn(30);
let mut iter = page_versions.ordered_page_version_iter(Some(CUTOFF_LSN));
for blknum in 0..BLOCKS {
for lsn in 0..CUTOFF_LSN.0 {
let (actual_blknum, actual_lsn, _pv) = iter.next().unwrap();
assert_eq!(actual_blknum, blknum);
assert_eq!(Lsn(lsn), actual_lsn);
}
}
assert!(iter.next().is_none());
assert!(iter.next().is_none()); // should be robust against excessive next() calls
Ok(())
}
}

View File

@@ -71,26 +71,15 @@ pub enum PageVersion {
}
///
/// Struct used to communicate across calls to 'get_page_reconstruct_data'.
/// Data needed to reconstruct a page version
///
/// Before first call to get_page_reconstruct_data, you can fill in 'page_img'
/// if you have an older cached version of the page available. That can save
/// work in 'get_page_reconstruct_data', as it can stop searching for page
/// versions when all the WAL records going back to the cached image have been
/// collected.
///
/// When get_page_reconstruct_data returns Complete, 'page_img' is set to an
/// image of the page, or the oldest WAL record in 'records' is a will_init-type
/// record that initializes the page without requiring a previous image.
///
/// If 'get_page_reconstruct_data' returns Continue, some 'records' may have
/// been collected, but there are more records outside the current layer. Pass
/// the same PageReconstructData struct in the next 'get_page_reconstruct_data'
/// call, to collect more records.
/// 'page_img' is the old base image of the page to start the WAL replay with.
/// It can be None, if the first WAL record initializes the page (will_init)
/// 'records' contains the records to apply over the base image.
///
pub struct PageReconstructData {
pub records: Vec<(Lsn, ZenithWalRecord)>,
pub page_img: Option<(Lsn, Bytes)>,
pub page_img: Option<Bytes>,
}
/// Return value from Layer::get_page_reconstruct_data
@@ -104,6 +93,8 @@ pub enum PageReconstructResult {
/// the returned LSN. This is usually considered an error, but might be OK
/// in some circumstances.
Missing(Lsn),
/// Use the cached image at `cached_img_lsn` as the base image
Cached,
}
///
@@ -147,16 +138,19 @@ pub trait Layer: Send + Sync {
/// It is up to the caller to collect more data from previous layer and
/// perform WAL redo, if necessary.
///
/// `cached_img_lsn` should be set to a cached page image's lsn < `lsn`.
/// This function will only return data after `cached_img_lsn`.
///
/// See PageReconstructResult for possible return values. The collected data
/// is appended to reconstruct_data; the caller should pass an empty struct
/// on first call, or a struct with a cached older image of the page if one
/// is available. If this returns PageReconstructResult::Continue, look up
/// the predecessor layer and call again with the same 'reconstruct_data' to
/// collect more data.
/// on first call. If this returns PageReconstructResult::Continue, look up
/// the predecessor layer and call again with the same 'reconstruct_data'
/// to collect more data.
fn get_page_reconstruct_data(
&self,
blknum: SegmentBlk,
lsn: Lsn,
cached_img_lsn: Option<Lsn>,
reconstruct_data: &mut PageReconstructData,
) -> Result<PageReconstructResult>;

View File

@@ -27,10 +27,13 @@ use zenith_utils::lsn::Lsn;
use zenith_utils::postgres_backend::is_socket_read_timed_out;
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::postgres_backend::{self, AuthType};
use zenith_utils::pq_proto::{BeMessage, FeMessage, RowDescriptor, SINGLE_COL_ROWDESC};
use zenith_utils::pq_proto::{
BeMessage, FeMessage, RowDescriptor, HELLO_WORLD_ROW, SINGLE_COL_ROWDESC,
};
use zenith_utils::zid::{ZTenantId, ZTimelineId};
use crate::basebackup;
use crate::branches;
use crate::config::PageServerConf;
use crate::relish::*;
use crate::repository::Timeline;
@@ -659,21 +662,79 @@ impl postgres_backend::Handler for PageServerHandler {
walreceiver::launch_wal_receiver(self.conf, tenantid, timelineid, &connstr)?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("branch_create ") {
let err = || format!("invalid branch_create: '{}'", query_string);
// branch_create <tenantid> <branchname> <startpoint>
// TODO lazy static
// TODO: escaping, to allow branch names with spaces
let re = Regex::new(r"^branch_create ([[:xdigit:]]+) (\S+) ([^\r\n\s;]+)[\r\n\s;]*;?$")
.unwrap();
let caps = re.captures(query_string).with_context(err)?;
let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;
let branchname = caps.get(2).with_context(err)?.as_str().to_owned();
let startpoint_str = caps.get(3).with_context(err)?.as_str().to_owned();
self.check_permission(Some(tenantid))?;
let _enter =
info_span!("branch_create", name = %branchname, tenant = %tenantid).entered();
let branch =
branches::create_branch(self.conf, &branchname, &startpoint_str, &tenantid)?;
let branch = serde_json::to_vec(&branch)?;
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(&branch)]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("branch_list ") {
// branch_list <zenith tenantid as hex string>
let re = Regex::new(r"^branch_list ([[:xdigit:]]+)$").unwrap();
let caps = re
.captures(query_string)
.with_context(|| format!("invalid branch_list: '{}'", query_string))?;
let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;
// since these handlers for tenant/branch commands are deprecated (in favor of http based ones)
// just use false in place of include non incremental logical size
let branches = crate::branches::get_branches(self.conf, &tenantid, false)?;
let branches_buf = serde_json::to_vec(&branches)?;
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(&branches_buf)]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("tenant_list") {
let tenants = crate::tenant_mgr::list_tenants()?;
let tenants_buf = serde_json::to_vec(&tenants)?;
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(&tenants_buf)]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("tenant_create") {
let err = || format!("invalid tenant_create: '{}'", query_string);
// tenant_create <tenantid>
let re = Regex::new(r"^tenant_create ([[:xdigit:]]+)$").unwrap();
let caps = re.captures(query_string).with_context(err)?;
self.check_permission(None)?;
let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;
tenant_mgr::create_repository_for_tenant(self.conf, tenantid)?;
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("status") {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&HELLO_WORLD_ROW)?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.to_ascii_lowercase().starts_with("set ") {
// important because psycopg2 executes "SET datestyle TO 'ISO'"
// on connect
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("failpoints ") {
let (_, failpoints) = query_string.split_at("failpoints ".len());
for failpoint in failpoints.split(';') {
if let Some((name, actions)) = failpoint.split_once('=') {
info!("cfg failpoint: {} {}", name, actions);
fail::cfg(name, actions).unwrap();
} else {
bail!("Invalid failpoints format");
}
}
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("do_gc ") {
// Run GC immediately on given timeline.
// FIXME: This is just for tests. See test_runner/batch_others/test_gc.py.

View File

@@ -447,6 +447,8 @@ pub mod repo_harness {
#[allow(clippy::bool_assert_comparison)]
#[cfg(test)]
mod tests {
use crate::layered_repository::metadata::METADATA_FILE_NAME;
use super::repo_harness::*;
use super::*;
use postgres_ffi::{pg_constants, xlog_utils::SIZEOF_CHECKPOINT};
@@ -744,8 +746,8 @@ mod tests {
let mut lsn = 0x10;
for blknum in 0..pg_constants::RELSEG_SIZE + 1 {
lsn += 0x10;
let img = TEST_IMG(&format!("foo blk {} at {}", blknum, Lsn(lsn)));
lsn += 0x10;
writer.put_page_image(TESTREL_A, blknum as BlockNumber, Lsn(lsn), img)?;
}
writer.advance_last_record_lsn(Lsn(lsn));
@@ -1130,4 +1132,141 @@ mod tests {
Ok(())
}
#[test]
fn corrupt_metadata() -> Result<()> {
const TEST_NAME: &str = "corrupt_metadata";
let harness = RepoHarness::create(TEST_NAME)?;
let repo = harness.load();
repo.create_empty_timeline(TIMELINE_ID, Lsn(0))?;
drop(repo);
let metadata_path = harness.timeline_path(&TIMELINE_ID).join(METADATA_FILE_NAME);
assert!(metadata_path.is_file());
let mut metadata_bytes = std::fs::read(&metadata_path)?;
assert_eq!(metadata_bytes.len(), 512);
metadata_bytes[512 - 4 - 2] ^= 1;
std::fs::write(metadata_path, metadata_bytes)?;
let new_repo = harness.load();
let err = new_repo.get_timeline(TIMELINE_ID).err().unwrap();
assert_eq!(err.to_string(), "failed to load metadata");
assert_eq!(
err.source().unwrap().to_string(),
"metadata checksum mismatch"
);
Ok(())
}
#[test]
fn future_layerfiles() -> Result<()> {
const TEST_NAME: &str = "future_layerfiles";
let harness = RepoHarness::create(TEST_NAME)?;
let repo = harness.load();
// Create a timeline with disk_consistent_lsn = 8000
let tline = repo.create_empty_timeline(TIMELINE_ID, Lsn(0x8000))?;
let writer = tline.writer();
writer.advance_last_record_lsn(Lsn(0x8000));
drop(writer);
repo.checkpoint_iteration(CheckpointConfig::Forced)?;
drop(repo);
let timeline_path = harness.timeline_path(&TIMELINE_ID);
let make_empty_file = |filename: &str| -> std::io::Result<()> {
let path = timeline_path.join(filename);
assert!(!path.exists());
std::fs::write(&path, &[])?;
Ok(())
};
// Helper function to check that a relation file exists, and a corresponding
// <filename>.0.old file does not.
let assert_exists = |filename: &str| {
let path = timeline_path.join(filename);
assert!(path.exists(), "file {} was removed", filename);
// Check that there is no .old file
let backup_path = timeline_path.join(format!("{}.0.old", filename));
assert!(
!backup_path.exists(),
"unexpected backup file {}",
backup_path.display()
);
};
// Helper function to check that a relation file does *not* exists, and a corresponding
// <filename>.<num>.old file does.
let assert_is_renamed = |filename: &str, num: u32| {
let path = timeline_path.join(filename);
assert!(
!path.exists(),
"file {} was not removed as expected",
filename
);
let backup_path = timeline_path.join(format!("{}.{}.old", filename, num));
assert!(
backup_path.exists(),
"backup file {} was not created",
backup_path.display()
);
};
// These files are considered to be in the future and will be renamed out
// of the way
let future_filenames = vec![
format!("pg_control_0_{:016X}", 0x8001),
format!("pg_control_0_{:016X}_{:016X}", 0x8001, 0x8008),
];
// But these are not:
let past_filenames = vec![
format!("pg_control_0_{:016X}", 0x8000),
format!("pg_control_0_{:016X}_{:016X}", 0x7000, 0x8001),
];
for filename in future_filenames.iter().chain(past_filenames.iter()) {
make_empty_file(filename)?;
}
// Load the timeline. This will cause the files in the "future" to be renamed
// away.
let new_repo = harness.load();
new_repo.get_timeline(TIMELINE_ID).unwrap();
drop(new_repo);
for filename in future_filenames.iter() {
assert_is_renamed(filename, 0);
}
for filename in past_filenames.iter() {
assert_exists(filename);
}
// Create the future files again, and load again. They should be renamed to
// *.1.old this time.
for filename in future_filenames.iter() {
make_empty_file(filename)?;
}
let new_repo = harness.load();
new_repo.get_timeline(TIMELINE_ID).unwrap();
drop(new_repo);
for filename in future_filenames.iter() {
assert_is_renamed(filename, 0);
assert_is_renamed(filename, 1);
}
for filename in past_filenames.iter() {
assert_exists(filename);
}
Ok(())
}
}

View File

@@ -37,6 +37,7 @@ use postgres_ffi::xlog_utils::*;
use postgres_ffi::TransactionId;
use postgres_ffi::{pg_constants, CheckPoint};
use zenith_utils::lsn::Lsn;
use zenith_utils::pg_checksum_page::pg_checksum_page;
static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; 8192]);
@@ -329,6 +330,9 @@ impl WalIngest {
}
image[0..4].copy_from_slice(&((lsn.0 >> 32) as u32).to_le_bytes());
image[4..8].copy_from_slice(&(lsn.0 as u32).to_le_bytes());
image[8..10].copy_from_slice(&[0u8; 2]);
let checksum = pg_checksum_page(&image, blk.blkno);
image[8..10].copy_from_slice(&checksum.to_le_bytes());
assert_eq!(image.len(), pg_constants::BLCKSZ as usize);
timeline.put_page_image(tag, blk.blkno, lsn, image.freeze())?;
} else {

View File

@@ -12,7 +12,6 @@ use crate::thread_mgr::ThreadKind;
use crate::walingest::WalIngest;
use anyhow::{bail, Context, Error, Result};
use bytes::BytesMut;
use fail::fail_point;
use lazy_static::lazy_static;
use postgres_ffi::waldecoder::*;
use postgres_protocol::message::backend::ReplicationMessage;
@@ -32,7 +31,6 @@ use zenith_utils::lsn::Lsn;
use zenith_utils::pq_proto::ZenithFeedback;
use zenith_utils::zid::ZTenantId;
use zenith_utils::zid::ZTimelineId;
//
// We keep one WAL Receiver active per timeline.
//
@@ -256,8 +254,6 @@ fn walreceiver_main(
let writer = timeline.writer();
walingest.ingest_record(writer.as_ref(), recdata, lsn)?;
fail_point!("walreceiver-after-ingest");
last_rec_lsn = lsn;
}

View File

@@ -268,11 +268,12 @@ impl XlXactParsedRecord {
let info = xl_info & pg_constants::XLOG_XACT_OPMASK;
// The record starts with time of commit/abort
let xact_time = buf.get_i64_le();
let xinfo = if xl_info & pg_constants::XLOG_XACT_HAS_INFO != 0 {
buf.get_u32_le()
let xinfo;
if xl_info & pg_constants::XLOG_XACT_HAS_INFO != 0 {
xinfo = buf.get_u32_le();
} else {
0
};
xinfo = 0;
}
let db_id;
let ts_id;
if xinfo & pg_constants::XACT_XINFO_HAS_DBINFO != 0 {
@@ -501,6 +502,7 @@ pub fn decode_wal_record(record: Bytes) -> DecodedWALRecord {
0..=pg_constants::XLR_MAX_BLOCK_ID => {
/* XLogRecordBlockHeader */
let mut blk = DecodedBkpBlock::new();
let fork_flags: u8;
if block_id <= max_block_id {
// TODO
@@ -513,7 +515,7 @@ pub fn decode_wal_record(record: Bytes) -> DecodedWALRecord {
}
max_block_id = block_id;
let fork_flags: u8 = buf.get_u8();
fork_flags = buf.get_u8();
blk.forknum = fork_flags & pg_constants::BKPBLOCK_FORK_MASK;
blk.flags = fork_flags;
blk.has_image = (fork_flags & pg_constants::BKPBLOCK_HAS_IMAGE) != 0;

View File

@@ -102,6 +102,8 @@ impl crate::walredo::WalRedoManager for DummyRedoManager {
}
}
static TIMEOUT: Duration = Duration::from_secs(20);
// Metrics collected on WAL redo operations
//
// We collect the time spent in actual WAL redo ('redo'), and time waiting
@@ -219,14 +221,7 @@ impl WalRedoManager for PostgresRedoManager {
let result = if batch_zenith {
self.apply_batch_zenith(rel, blknum, lsn, img, &records[batch_start..i])
} else {
self.apply_batch_postgres(
rel,
blknum,
lsn,
img,
&records[batch_start..i],
self.conf.wal_redo_timeout,
)
self.apply_batch_postgres(rel, blknum, lsn, img, &records[batch_start..i])
};
img = Some(result?);
@@ -238,14 +233,7 @@ impl WalRedoManager for PostgresRedoManager {
if batch_zenith {
self.apply_batch_zenith(rel, blknum, lsn, img, &records[batch_start..])
} else {
self.apply_batch_postgres(
rel,
blknum,
lsn,
img,
&records[batch_start..],
self.conf.wal_redo_timeout,
)
self.apply_batch_postgres(rel, blknum, lsn, img, &records[batch_start..])
}
}
}
@@ -273,7 +261,6 @@ impl PostgresRedoManager {
lsn: Lsn,
base_img: Option<Bytes>,
records: &[(Lsn, ZenithWalRecord)],
wal_redo_timeout: Duration,
) -> Result<Bytes, WalRedoError> {
let start_time = Instant::now();
@@ -294,7 +281,7 @@ impl PostgresRedoManager {
let result = if let RelishTag::Relation(rel) = rel {
// Relational WAL records are applied using wal-redo-postgres
let buf_tag = BufferTag { rel, blknum };
apply_result = process.apply_wal_records(buf_tag, base_img, records, wal_redo_timeout);
apply_result = process.apply_wal_records(buf_tag, base_img, records);
apply_result.map_err(WalRedoError::IoError)
} else {
@@ -616,7 +603,6 @@ impl PostgresRedoProcess {
tag: BufferTag,
base_img: Option<Bytes>,
records: &[(Lsn, ZenithWalRecord)],
wal_redo_timeout: Duration,
) -> Result<Bytes, std::io::Error> {
// Serialize all the messages to send the WAL redo process first.
//
@@ -667,7 +653,7 @@ impl PostgresRedoProcess {
// If we have more data to write, wake up if 'stdin' becomes writeable or
// we have data to read. Otherwise only wake up if there's data to read.
let nfds = if nwrite < writebuf.len() { 3 } else { 2 };
let n = nix::poll::poll(&mut pollfds[0..nfds], wal_redo_timeout.as_millis() as i32)?;
let n = nix::poll::poll(&mut pollfds[0..nfds], TIMEOUT.as_millis() as i32)?;
if n == 0 {
return Err(Error::new(ErrorKind::Other, "WAL redo timed out"));

25
poetry.lock generated
View File

@@ -91,14 +91,6 @@ botocore = ">=1.11.3"
future = "*"
wrapt = "*"
[[package]]
name = "backoff"
version = "1.11.1"
description = "Function decoration for backoff and retry"
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "boto3"
version = "1.20.40"
@@ -822,11 +814,11 @@ python-versions = "*"
[[package]]
name = "moto"
version = "3.0.4"
version = "3.0.0"
description = "A library that allows your python tests to easily mock out the boto library"
category = "main"
optional = false
python-versions = ">=3.6"
python-versions = "*"
[package.dependencies]
aws-xray-sdk = {version = ">=0.93,<0.96 || >0.96", optional = true, markers = "extra == \"server\""}
@@ -856,8 +848,7 @@ xmltodict = "*"
[package.extras]
all = ["PyYAML (>=5.1)", "python-jose[cryptography] (>=3.1.0,<4.0.0)", "ecdsa (!=0.15)", "docker (>=2.5.1)", "graphql-core", "jsondiff (>=1.1.2)", "aws-xray-sdk (>=0.93,!=0.96)", "idna (>=2.5,<4)", "cfn-lint (>=0.4.0)", "sshpubkeys (>=3.1.0)", "setuptools"]
apigateway = ["PyYAML (>=5.1)", "python-jose[cryptography] (>=3.1.0,<4.0.0)", "ecdsa (!=0.15)"]
apigatewayv2 = ["PyYAML (>=5.1)"]
apigateway = ["python-jose[cryptography] (>=3.1.0,<4.0.0)", "ecdsa (!=0.15)"]
appsync = ["graphql-core"]
awslambda = ["docker (>=2.5.1)"]
batch = ["docker (>=2.5.1)"]
@@ -1361,7 +1352,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
[metadata]
lock-version = "1.1"
python-versions = "^3.7"
content-hash = "58762accad4122026c650fa43421a900546e89f9908e2268410e7b11cc8c6c4e"
content-hash = "0fa6c9377fbc827240d18d8b7e3742def37e90fc3277fddf8525d82dabd13090"
[metadata.files]
aiopg = [
@@ -1404,10 +1395,6 @@ aws-xray-sdk = [
{file = "aws-xray-sdk-2.9.0.tar.gz", hash = "sha256:b0cd972db218d4d8f7b53ad806fc6184626b924c4997ae58fc9f2a8cd1281568"},
{file = "aws_xray_sdk-2.9.0-py2.py3-none-any.whl", hash = "sha256:98216b3ac8281b51b59a8703f8ec561c460807d9d0679838f5c0179d381d7e58"},
]
backoff = [
{file = "backoff-1.11.1-py2.py3-none-any.whl", hash = "sha256:61928f8fa48d52e4faa81875eecf308eccfb1016b018bb6bd21e05b5d90a96c5"},
{file = "backoff-1.11.1.tar.gz", hash = "sha256:ccb962a2378418c667b3c979b504fdeb7d9e0d29c0579e3b13b86467177728cb"},
]
boto3 = [
{file = "boto3-1.20.40-py3-none-any.whl", hash = "sha256:cfe85589e4a0a997c7b9ae7432400b03fa6fa5fea29fdc48db3099a903b76998"},
{file = "boto3-1.20.40.tar.gz", hash = "sha256:66aef9a6d8cad393f69166112ba49e14e2c6766f9278c96134101314a9af2992"},
@@ -1679,8 +1666,8 @@ mccabe = [
{file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"},
]
moto = [
{file = "moto-3.0.4-py2.py3-none-any.whl", hash = "sha256:79646213d8438385182f4eea79e28725f94b3d0d3dc9a3eda81db47e0ebef6cc"},
{file = "moto-3.0.4.tar.gz", hash = "sha256:168b8a3cb4dd8a6df8e51d582761cefa9657b9f45ac7e1eb24dae394ebc9e000"},
{file = "moto-3.0.0-py2.py3-none-any.whl", hash = "sha256:762d33bbad3642c687f6495e69331318bef43f9aa662174397706ec3ad2a3578"},
{file = "moto-3.0.0.tar.gz", hash = "sha256:d6b00a2663290e7ebb06823d5ffcb124c8dc9bf526b878539ef7c4a377fd8255"},
]
mypy = [
{file = "mypy-0.910-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:a155d80ea6cee511a3694b108c4494a39f42de11ee4e61e72bc424c490e46457"},

View File

@@ -6,28 +6,18 @@ edition = "2021"
[dependencies]
anyhow = "1.0"
bytes = { version = "1.0.1", features = ['serde'] }
clap = "3.0"
futures = "0.3.13"
hashbrown = "0.11.2"
hex = "0.4.3"
hyper = "0.14"
lazy_static = "1.4.0"
md5 = "0.7.0"
parking_lot = "0.11.2"
pin-project-lite = "0.2.7"
rand = "0.8.3"
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
rustls = "0.19.1"
scopeguard = "1.1.0"
hex = "0.4.3"
hyper = "0.14"
serde = "1"
serde_json = "1"
tokio = { version = "1.11", features = ["macros"] }
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
tokio-rustls = "0.22.0"
clap = "3.0"
rustls = "0.19.1"
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
zenith_utils = { path = "../zenith_utils" }
zenith_metrics = { path = "../zenith_metrics" }
[dev-dependencies]
tokio-postgres-rustls = "0.8.0"
rcgen = "0.8.14"

View File

@@ -1,169 +0,0 @@
use crate::compute::DatabaseInfo;
use crate::config::ProxyConfig;
use crate::cplane_api::{self, CPlaneApi};
use crate::stream::PqStream;
use anyhow::{anyhow, bail, Context};
use std::collections::HashMap;
use tokio::io::{AsyncRead, AsyncWrite};
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe};
/// Various client credentials which we use for authentication.
#[derive(Debug, PartialEq, Eq)]
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
}
impl TryFrom<HashMap<String, String>> for ClientCredentials {
type Error = anyhow::Error;
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
let mut get_param = |key| {
value
.remove(key)
.with_context(|| format!("{} is missing in startup packet", key))
};
let user = get_param("user")?;
let db = get_param("database")?;
Ok(Self { user, dbname: db })
}
}
impl ClientCredentials {
/// Use credentials to authenticate the user.
pub async fn authenticate(
self,
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<DatabaseInfo> {
use crate::config::ClientAuthMethod::*;
use crate::config::RouterConfig::*;
let db_info = match &config.router_config {
Static { host, port } => handle_static(host.clone(), *port, client, self).await,
Dynamic(Mixed) => {
if self.user.ends_with("@zenith") {
handle_existing_user(config, client, self).await
} else {
handle_new_user(config, client).await
}
}
Dynamic(Password) => handle_existing_user(config, client, self).await,
Dynamic(Link) => handle_new_user(config, client).await,
};
db_info.context("failed to authenticate client")
}
}
fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
async fn handle_static(
host: String,
port: u16,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> anyhow::Result<DatabaseInfo> {
client
.write_message(&Be::AuthenticationCleartextPassword)
.await?;
// Read client's password bytes
let msg = match client.read_message().await? {
Fe::PasswordMessage(msg) => msg,
bad => bail!("unexpected message type: {:?}", bad),
};
let cleartext_password = std::str::from_utf8(&msg)?.split('\0').next().unwrap();
let db_info = DatabaseInfo {
host,
port,
dbname: creds.dbname.clone(),
user: creds.user.clone(),
password: Some(cleartext_password.into()),
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
}
async fn handle_existing_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> anyhow::Result<DatabaseInfo> {
let psql_session_id = new_psql_session_id();
let md5_salt = rand::random();
client
.write_message(&Be::AuthenticationMD5Password(&md5_salt))
.await?;
// Read client's password hash
let msg = match client.read_message().await? {
Fe::PasswordMessage(msg) => msg,
bad => bail!("unexpected message type: {:?}", bad),
};
let (_trailing_null, md5_response) = msg
.split_last()
.ok_or_else(|| anyhow!("unexpected password message"))?;
let cplane = CPlaneApi::new(&config.auth_endpoint);
let db_info = cplane
.authenticate_proxy_request(creds, md5_response, &md5_salt, &psql_session_id)
.await?;
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
}
async fn handle_new_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<DatabaseInfo> {
let psql_session_id = new_psql_session_id();
let greeting = hello_message(&config.redirect_uri, &psql_session_id);
let db_info = cplane_api::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&Be::NoticeResponse(greeting))
.await?;
// Wait for web console response
waiter.await?.map_err(|e| anyhow!(e))
})
.await?;
client.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
Ok(db_info)
}
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
format!(
concat![
"☀️ Welcome to Zenith!\n",
"To proceed with database creation, open the following link:\n\n",
" {redirect_uri}{session_id}\n\n",
"It needs to be done once and we will send you '.pgpass' file,\n",
"which will allow you to access or create ",
"databases without opening your web browser."
],
redirect_uri = redirect_uri,
session_id = session_id,
)
}

View File

@@ -1,106 +0,0 @@
use anyhow::{anyhow, Context};
use hashbrown::HashMap;
use parking_lot::Mutex;
use std::net::SocketAddr;
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use zenith_utils::pq_proto::CancelKeyData;
/// Enables serving CancelRequests.
#[derive(Default)]
pub struct CancelMap(Mutex<HashMap<CancelKeyData, Option<CancelClosure>>>);
impl CancelMap {
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
let cancel_closure = self
.0
.lock()
.get(&key)
.and_then(|x| x.clone())
.with_context(|| format!("unknown session: {:?}", key))?;
cancel_closure.try_cancel_query().await
}
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
where
F: FnOnce(Session<'a>) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
let key = rand::random();
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
self.0
.lock()
.try_insert(key, None)
.map_err(|_| anyhow!("session already exists: {:?}", key))?;
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.lock().remove(&key);
}
let session = Session::new(key, self);
f(session).await
}
}
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Clone)]
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: CancelToken,
}
impl CancelClosure {
pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
Self {
socket_addr,
cancel_token,
}
}
/// Cancels the query running on user's compute node.
pub async fn try_cancel_query(self) -> anyhow::Result<()> {
let socket = TcpStream::connect(self.socket_addr).await?;
self.cancel_token.cancel_query_raw(socket, NoTls).await?;
Ok(())
}
}
/// Helper for registering query cancellation tokens.
pub struct Session<'a> {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancel_map: &'a CancelMap,
}
impl<'a> Session<'a> {
fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
Self { key, cancel_map }
}
/// Store the cancel token for the given session.
/// This enables query cancellation in [`crate::proxy::handshake`].
pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
self.cancel_map
.0
.lock()
.insert(self.key, Some(cancel_closure));
self.key
}
}

View File

@@ -1,42 +0,0 @@
use anyhow::Context;
use serde::{Deserialize, Serialize};
use std::net::{SocketAddr, ToSocketAddrs};
/// Compute node connection params.
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct DatabaseInfo {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
pub password: Option<String>,
}
impl DatabaseInfo {
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
let host_port = format!("{}:{}", self.host, self.port);
host_port
.to_socket_addrs()
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
.next()
.context("cannot resolve at least one SocketAddr")
}
}
impl From<DatabaseInfo> for tokio_postgres::Config {
fn from(db_info: DatabaseInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
}
config
}
}

View File

@@ -1,79 +1,18 @@
use crate::auth::ClientCredentials;
use crate::compute::DatabaseInfo;
use crate::waiters::{Waiter, Waiters};
use anyhow::{anyhow, bail};
use lazy_static::lazy_static;
use anyhow::{anyhow, bail, Context};
use serde::{Deserialize, Serialize};
use std::net::{SocketAddr, ToSocketAddrs};
lazy_static! {
static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfo, String>> = Default::default();
use crate::state::ProxyWaiters;
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct DatabaseInfo {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
pub password: Option<String>,
}
/// Give caller an opportunity to wait for cplane's reply.
pub async fn with_waiter<F, R, T>(psql_session_id: impl Into<String>, f: F) -> anyhow::Result<T>
where
F: FnOnce(Waiter<'static, Result<DatabaseInfo, String>>) -> R,
R: std::future::Future<Output = anyhow::Result<T>>,
{
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
f(waiter).await
}
pub fn notify(psql_session_id: &str, msg: Result<DatabaseInfo, String>) -> anyhow::Result<()> {
CPLANE_WAITERS.notify(psql_session_id, msg)
}
/// Zenith console API wrapper.
pub struct CPlaneApi<'a> {
auth_endpoint: &'a str,
}
impl<'a> CPlaneApi<'a> {
pub fn new(auth_endpoint: &'a str) -> Self {
Self { auth_endpoint }
}
}
impl CPlaneApi<'_> {
pub async fn authenticate_proxy_request(
&self,
creds: ClientCredentials,
md5_response: &[u8],
salt: &[u8; 4],
psql_session_id: &str,
) -> anyhow::Result<DatabaseInfo> {
let mut url = reqwest::Url::parse(self.auth_endpoint)?;
url.query_pairs_mut()
.append_pair("login", &creds.user)
.append_pair("database", &creds.dbname)
.append_pair("md5response", std::str::from_utf8(md5_response)?)
.append_pair("salt", &hex::encode(salt))
.append_pair("psql_session_id", psql_session_id);
with_waiter(psql_session_id, |waiter| async {
println!("cplane request: {}", url);
// TODO: leverage `reqwest::Client` to reuse connections
let resp = reqwest::get(url).await?;
if !resp.status().is_success() {
bail!("Auth failed: {}", resp.status())
}
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
println!("got auth info: #{:?}", auth_info);
use ProxyAuthResponse::*;
match auth_info {
Ready { conn_info } => Ok(conn_info),
Error { error } => bail!(error),
NotReady { .. } => waiter.await?.map_err(|e| anyhow!(e)),
}
})
.await
}
}
// NOTE: the order of constructors is important.
// https://serde.rs/enum-representations.html#untagged
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
enum ProxyAuthResponse {
@@ -82,6 +21,86 @@ enum ProxyAuthResponse {
NotReady { ready: bool }, // TODO: get rid of `ready`
}
impl DatabaseInfo {
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
let host_port = format!("{}:{}", self.host, self.port);
host_port
.to_socket_addrs()
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
.next()
.context("cannot resolve at least one SocketAddr")
}
}
impl From<DatabaseInfo> for tokio_postgres::Config {
fn from(db_info: DatabaseInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
}
config
}
}
pub struct CPlaneApi<'a> {
auth_endpoint: &'a str,
waiters: &'a ProxyWaiters,
}
impl<'a> CPlaneApi<'a> {
pub fn new(auth_endpoint: &'a str, waiters: &'a ProxyWaiters) -> Self {
Self {
auth_endpoint,
waiters,
}
}
}
impl CPlaneApi<'_> {
pub fn authenticate_proxy_request(
&self,
user: &str,
database: &str,
md5_response: &[u8],
salt: &[u8; 4],
psql_session_id: &str,
) -> anyhow::Result<DatabaseInfo> {
let mut url = reqwest::Url::parse(self.auth_endpoint)?;
url.query_pairs_mut()
.append_pair("login", user)
.append_pair("database", database)
.append_pair("md5response", std::str::from_utf8(md5_response)?)
.append_pair("salt", &hex::encode(salt))
.append_pair("psql_session_id", psql_session_id);
let waiter = self.waiters.register(psql_session_id.to_owned());
println!("cplane request: {}", url);
let resp = reqwest::blocking::get(url)?;
if !resp.status().is_success() {
bail!("Auth failed: {}", resp.status())
}
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text()?.as_str())?;
println!("got auth info: #{:?}", auth_info);
use ProxyAuthResponse::*;
match auth_info {
Ready { conn_info } => Ok(conn_info),
Error { error } => bail!(error),
NotReady { .. } => waiter.wait()?.map_err(|e| anyhow!(e)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,30 +1,15 @@
use anyhow::anyhow;
use hyper::{Body, Request, Response, StatusCode};
use std::net::TcpListener;
use zenith_utils::http::RouterBuilder;
use zenith_utils::http::endpoint;
use zenith_utils::http::error::ApiError;
use zenith_utils::http::json::json_response;
use zenith_utils::http::{RouterBuilder, RouterService};
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
Ok(json_response(StatusCode::OK, "")?)
}
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
let router = endpoint::make_router();
router.get("/v1/status", status_handler)
}
pub async fn thread_main(http_listener: TcpListener) -> anyhow::Result<()> {
scopeguard::defer! {
println!("http has shut down");
}
let service = || RouterService::new(make_router().build()?);
hyper::Server::from_tcp(http_listener)?
.serve(service().map_err(|e| anyhow!(e))?)
.await?;
Ok(())
}

View File

@@ -5,36 +5,21 @@
/// (control plane API in our case) and can create new databases and accounts
/// in somewhat transparent manner (again via communication with control plane API).
///
use anyhow::{bail, Context};
use anyhow::bail;
use clap::{App, Arg};
use config::ProxyConfig;
use futures::FutureExt;
use std::future::Future;
use tokio::{net::TcpListener, task::JoinError};
use zenith_utils::GIT_VERSION;
use state::{ProxyConfig, ProxyState};
use std::thread;
use zenith_utils::http::endpoint;
use zenith_utils::{tcp_listener, GIT_VERSION};
use crate::config::{ClientAuthMethod, RouterConfig};
mod auth;
mod cancellation;
mod compute;
mod config;
mod cplane_api;
mod http;
mod mgmt;
mod proxy;
mod stream;
mod state;
mod waiters;
/// Flattens Result<Result<T>> into Result<T>.
async fn flatten_err(
f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,
) -> anyhow::Result<()> {
f.map(|r| r.context("join error").and_then(|x| x)).await
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
fn main() -> anyhow::Result<()> {
zenith_metrics::set_common_metrics_prefix("zenith_proxy");
let arg_matches = App::new("Zenith proxy/router")
.version(GIT_VERSION)
@@ -46,20 +31,6 @@ async fn main() -> anyhow::Result<()> {
.help("listen for incoming client connections on ip:port")
.default_value("127.0.0.1:4432"),
)
.arg(
Arg::new("auth-method")
.long("auth-method")
.takes_value(true)
.help("Possible values: password | link | mixed")
.default_value("mixed"),
)
.arg(
Arg::new("static-router")
.short('s')
.long("static-router")
.takes_value(true)
.help("Route all clients to host:port"),
)
.arg(
Arg::new("mgmt")
.short('m')
@@ -108,59 +79,63 @@ async fn main() -> anyhow::Result<()> {
)
.get_matches();
let tls_config = match (
let ssl_config = match (
arg_matches.value_of("ssl-key"),
arg_matches.value_of("ssl-cert"),
) {
(Some(key_path), Some(cert_path)) => Some(config::configure_ssl(key_path, cert_path)?),
(Some(key_path), Some(cert_path)) => {
Some(crate::state::configure_ssl(key_path, cert_path)?)
}
(None, None) => None,
_ => bail!("either both or neither ssl-key and ssl-cert must be specified"),
};
let auth_method = arg_matches.value_of("auth-method").unwrap().parse()?;
let router_config = match arg_matches.value_of("static-router") {
None => RouterConfig::Dynamic(auth_method),
Some(addr) => {
if let ClientAuthMethod::Password = auth_method {
let (host, port) = addr.split_once(':').unwrap();
RouterConfig::Static {
host: host.to_string(),
port: port.parse().unwrap(),
}
} else {
bail!("static-router requires --auth-method password")
}
}
};
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
router_config,
let config = ProxyConfig {
proxy_address: arg_matches.value_of("proxy").unwrap().parse()?,
mgmt_address: arg_matches.value_of("mgmt").unwrap().parse()?,
http_address: arg_matches.value_of("http").unwrap().parse()?,
redirect_uri: arg_matches.value_of("uri").unwrap().parse()?,
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
tls_config,
}));
ssl_config,
};
let state: &ProxyState = Box::leak(Box::new(ProxyState::new(config)));
println!("Version: {}", GIT_VERSION);
// Check that we can bind to address before further initialization
println!("Starting http on {}", config.http_address);
let http_listener = TcpListener::bind(config.http_address).await?.into_std()?;
println!("Starting http on {}", state.conf.http_address);
let http_listener = tcp_listener::bind(state.conf.http_address)?;
println!("Starting mgmt on {}", config.mgmt_address);
let mgmt_listener = TcpListener::bind(config.mgmt_address).await?.into_std()?;
println!("Starting proxy on {}", state.conf.proxy_address);
let pageserver_listener = tcp_listener::bind(state.conf.proxy_address)?;
println!("Starting proxy on {}", config.proxy_address);
let proxy_listener = TcpListener::bind(config.proxy_address).await?;
println!("Starting mgmt on {}", state.conf.mgmt_address);
let mgmt_listener = tcp_listener::bind(state.conf.mgmt_address)?;
let http = tokio::spawn(http::thread_main(http_listener));
let proxy = tokio::spawn(proxy::thread_main(config, proxy_listener));
let mgmt = tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener));
let threads = [
thread::Builder::new()
.name("Http thread".into())
.spawn(move || {
let router = http::make_router();
endpoint::serve_thread_main(
router,
http_listener,
std::future::pending(), // never shut down
)
})?,
// Spawn a thread to listen for connections. It will spawn further threads
// for each connection.
thread::Builder::new()
.name("Listener thread".into())
.spawn(move || proxy::thread_main(state, pageserver_listener))?,
thread::Builder::new()
.name("Mgmt thread".into())
.spawn(move || mgmt::thread_main(state, mgmt_listener))?,
];
let tasks = [flatten_err(http), flatten_err(proxy), flatten_err(mgmt)];
let _: Vec<()> = futures::future::try_join_all(tasks).await?;
for t in threads {
t.join().unwrap()?;
}
Ok(())
}

View File

@@ -1,49 +1,44 @@
use crate::{compute::DatabaseInfo, cplane_api};
use anyhow::Context;
use serde::Deserialize;
use std::{
net::{TcpListener, TcpStream},
thread,
};
use serde::Deserialize;
use zenith_utils::{
postgres_backend::{self, AuthType, PostgresBackend},
pq_proto::{BeMessage, SINGLE_COL_ROWDESC},
};
use crate::{cplane_api::DatabaseInfo, ProxyState};
///
/// Main proxy listener loop.
///
/// Listens for connections, and launches a new handler thread for each.
///
pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
scopeguard::defer! {
println!("mgmt has shut down");
}
listener
.set_nonblocking(false)
.context("failed to set listener to blocking")?;
pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow::Result<()> {
loop {
let (socket, peer_addr) = listener.accept().context("failed to accept a new client")?;
let (socket, peer_addr) = listener.accept()?;
println!("accepted connection from {}", peer_addr);
socket
.set_nodelay(true)
.context("failed to set client socket option")?;
socket.set_nodelay(true).unwrap();
thread::spawn(move || {
if let Err(err) = handle_connection(socket) {
if let Err(err) = handle_connection(state, socket) {
println!("error: {}", err);
}
});
}
}
fn handle_connection(socket: TcpStream) -> anyhow::Result<()> {
fn handle_connection(state: &ProxyState, socket: TcpStream) -> anyhow::Result<()> {
let mut conn_handler = MgmtHandler { state };
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?;
pgbackend.run(&mut MgmtHandler)
pgbackend.run(&mut conn_handler)
}
struct MgmtHandler;
struct MgmtHandler<'a> {
state: &'a ProxyState,
}
/// Serialized examples:
// {
@@ -79,13 +74,13 @@ enum PsqlSessionResult {
Failure(String),
}
impl postgres_backend::Handler for MgmtHandler {
impl postgres_backend::Handler for MgmtHandler<'_> {
fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> anyhow::Result<()> {
let res = try_process_query(pgb, query_string);
let res = try_process_query(self, pgb, query_string);
// intercept and log error message
if res.is_err() {
println!("Mgmt query failed: #{:?}", res);
@@ -94,7 +89,11 @@ impl postgres_backend::Handler for MgmtHandler {
}
}
fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::Result<()> {
fn try_process_query(
mgmt: &mut MgmtHandler,
pgb: &mut PostgresBackend,
query_string: &str,
) -> anyhow::Result<()> {
println!("Got mgmt query: '{}'", query_string);
let resp: PsqlSessionResponse = serde_json::from_str(query_string)?;
@@ -105,7 +104,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R
Failure(message) => Err(message),
};
match cplane_api::notify(&resp.session_id, msg) {
match mgmt.state.waiters.notify(&resp.session_id, msg) {
Ok(()) => {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?

View File

@@ -1,332 +1,389 @@
use crate::auth;
use crate::cancellation::{self, CancelClosure, CancelMap};
use crate::compute::DatabaseInfo;
use crate::config::{ProxyConfig, TlsConfig};
use crate::stream::{MetricsStream, PqStream, Stream};
use anyhow::{bail, Context};
use crate::cplane_api::{CPlaneApi, DatabaseInfo};
use crate::ProxyState;
use anyhow::{anyhow, bail, Context};
use lazy_static::lazy_static;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use rand::prelude::StdRng;
use rand::{Rng, SeedableRng};
use std::cell::Cell;
use std::collections::HashMap;
use std::net::{SocketAddr, TcpStream};
use std::sync::Mutex;
use std::{io, thread};
use tokio_postgres::NoTls;
use zenith_metrics::{new_common_metric_name, register_int_counter, IntCounter};
use zenith_utils::pq_proto::{BeMessage as Be, *};
use zenith_utils::postgres_backend::{self, PostgresBackend, ProtoState, Stream};
use zenith_utils::pq_proto::{BeMessage as Be, FeMessage as Fe, *};
use zenith_utils::sock_split::{ReadStream, WriteStream};
lazy_static! {
static ref NUM_CONNECTIONS_ACCEPTED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_accepted"),
"Number of TCP client connections accepted."
)
.unwrap();
static ref NUM_CONNECTIONS_CLOSED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_closed"),
"Number of TCP client connections closed."
)
.unwrap();
static ref NUM_BYTES_PROXIED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_bytes_proxied"),
"Number of bytes sent/received between any client and backend."
)
.unwrap();
struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: tokio_postgres::CancelToken,
}
async fn log_error<R, F>(future: F) -> F::Output
where
F: std::future::Future<Output = anyhow::Result<R>>,
{
future.await.map_err(|err| {
println!("error: {}", err);
err
})
}
pub async fn thread_main(
config: &'static ProxyConfig,
listener: tokio::net::TcpListener,
) -> anyhow::Result<()> {
scopeguard::defer! {
println!("proxy has shut down");
}
let cancel_map = Arc::new(CancelMap::default());
loop {
let (socket, peer_addr) = listener.accept().await?;
println!("accepted connection from {}", peer_addr);
let cancel_map = Arc::clone(&cancel_map);
tokio::spawn(log_error(async move {
socket
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(config, &cancel_map, socket).await
}));
}
}
async fn handle_client(
config: &ProxyConfig,
cancel_map: &CancelMap,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
scopeguard::defer! {
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
}
let tls = config.tls_config.clone();
if let Some((client, creds)) = handshake(stream, tls, cancel_map).await? {
cancel_map
.with_session(|session| async {
connect_client_to_db(config, session, client, creds).await
})
.await?;
}
Ok(())
}
/// Handle a connection from one client.
/// For better testing experience, `stream` can be
/// any object satisfying the traits.
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, auth::ClientCredentials)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
println!("got message: {:?}", msg);
use FeStartupPacket::*;
match msg {
SslRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
stream = PqStream::new(stream.into_inner().upgrade(tls).await?);
}
}
_ => bail!("protocol violation"),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
tried_gss = true;
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!("protocol violation"),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
let msg = "connection is insecure (try using `sslmode=require`)";
stream.write_message(&Be::ErrorResponse(msg)).await?;
bail!(msg);
}
break Ok(Some((stream, params.try_into()?)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;
break Ok(None);
}
impl CancelClosure {
async fn try_cancel_query(&self) {
if let Ok(socket) = tokio::net::TcpStream::connect(self.socket_addr).await {
// NOTE ignoring the result because:
// 1. This is a best effort attempt, the database doesn't have to listen
// 2. Being opaque about errors here helps avoid leaking info to unauthenticated user
let _ = self.cancel_token.cancel_query_raw(socket, NoTls).await;
}
}
}
async fn connect_client_to_db(
config: &ProxyConfig,
session: cancellation::Session<'_>,
mut client: PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: auth::ClientCredentials,
lazy_static! {
// Enables serving CancelRequests
static ref CANCEL_MAP: Mutex<HashMap<CancelKeyData, CancelClosure>> = Mutex::new(HashMap::new());
// Metrics
static ref NUM_CONNECTIONS_ACCEPTED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_accepted"),
"Number of TCP client connections accepted."
).unwrap();
static ref NUM_CONNECTIONS_CLOSED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_closed"),
"Number of TCP client connections closed."
).unwrap();
static ref NUM_CONNECTIONS_FAILED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_failed"),
"Number of TCP client connections that closed due to error."
).unwrap();
static ref NUM_BYTES_PROXIED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_bytes_proxied"),
"Number of bytes sent/received between any client and backend."
).unwrap();
}
thread_local! {
// Used to clean up the CANCEL_MAP. Might not be necessary if we use tokio thread pool in main loop.
static THREAD_CANCEL_KEY_DATA: Cell<Option<CancelKeyData>> = Cell::new(None);
}
///
/// Main proxy listener loop.
///
/// Listens for connections, and launches a new handler thread for each.
///
pub fn thread_main(
state: &'static ProxyState,
listener: std::net::TcpListener,
) -> anyhow::Result<()> {
let db_info = creds.authenticate(config, &mut client).await?;
let (db, version, cancel_closure) = connect_to_db(db_info).await?;
let cancel_key_data = session.enable_cancellation(cancel_closure);
loop {
let (socket, peer_addr) = listener.accept()?;
println!("accepted connection from {}", peer_addr);
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
socket.set_nodelay(true).unwrap();
client
.write_message_noflush(&BeMessage::ParameterStatus(
BeParameterStatusMessage::ServerVersion(&version),
))?
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&BeMessage::ReadyForQuery)
.await?;
// TODO Use a threadpool instead. Maybe use tokio's threadpool by
// spawning a future into its runtime. Tokio's JoinError should
// allow us to handle cleanup properly even if the future panics.
thread::Builder::new()
.name("Proxy thread".into())
.spawn(move || {
if let Err(err) = proxy_conn_main(state, socket) {
NUM_CONNECTIONS_FAILED_COUNTER.inc();
println!("error: {}", err);
}
// This function will be called for writes to either direction.
fn inc_proxied(cnt: usize) {
// Consider inventing something more sophisticated
// if this ever becomes a bottleneck (cacheline bouncing).
NUM_BYTES_PROXIED_COUNTER.inc_by(cnt as u64);
// Clean up CANCEL_MAP.
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
THREAD_CANCEL_KEY_DATA.with(|cell| {
if let Some(cancel_key_data) = cell.get() {
CANCEL_MAP.lock().unwrap().remove(&cancel_key_data);
};
});
})?;
}
}
// TODO: clean up fields
struct ProxyConnection {
state: &'static ProxyState,
psql_session_id: String,
pgb: PostgresBackend,
}
pub fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> {
let conn = ProxyConnection {
state,
psql_session_id: hex::encode(rand::random::<[u8; 8]>()),
pgb: PostgresBackend::new(
socket,
postgres_backend::AuthType::MD5,
state.conf.ssl_config.clone(),
false,
)?,
};
let (client, server) = match conn.handle_client()? {
Some(x) => x,
None => return Ok(()),
};
let server = zenith_utils::sock_split::BidiStream::from_tcp(server);
let client = match client {
Stream::Bidirectional(bidi_stream) => bidi_stream,
_ => panic!("invalid stream type"),
};
proxy(client.split(), server.split())
}
impl ProxyConnection {
/// Returns Ok(None) when connection was successfully closed.
fn handle_client(mut self) -> anyhow::Result<Option<(Stream, TcpStream)>> {
let mut authenticate = || {
let (username, dbname) = match self.handle_startup()? {
Some(x) => x,
None => return Ok(None),
};
// Both scenarios here should end up producing database credentials
if username.ends_with("@zenith") {
self.handle_existing_user(&username, &dbname).map(Some)
} else {
self.handle_new_user().map(Some)
}
};
let conn = match authenticate() {
Ok(Some(db_info)) => connect_to_db(db_info),
Ok(None) => return Ok(None),
Err(e) => {
// Report the error to the client
self.pgb.write_message(&Be::ErrorResponse(&e.to_string()))?;
bail!("failed to handle client: {:?}", e);
}
};
// We'll get rid of this once migration to async is complete
let (pg_version, db_stream) = {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let (pg_version, stream, cancel_key_data) = runtime.block_on(conn)?;
self.pgb
.write_message(&BeMessage::BackendKeyData(cancel_key_data))?;
let stream = stream.into_std()?;
stream.set_nonblocking(false)?;
(pg_version, stream)
};
// Let the client send new requests
self.pgb
.write_message_noflush(&BeMessage::ParameterStatus(
BeParameterStatusMessage::ServerVersion(&pg_version),
))?
.write_message(&Be::ReadyForQuery)?;
Ok(Some((self.pgb.into_stream(), db_stream)))
}
let mut db = MetricsStream::new(db, inc_proxied);
let mut client = MetricsStream::new(client.into_inner(), inc_proxied);
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
/// Returns Ok(None) when connection was successfully closed.
fn handle_startup(&mut self) -> anyhow::Result<Option<(String, String)>> {
let have_tls = self.pgb.tls_config.is_some();
let mut encrypted = false;
loop {
let msg = match self.pgb.read_message()? {
Some(Fe::StartupPacket(msg)) => msg,
None => bail!("connection is lost"),
bad => bail!("unexpected message type: {:?}", bad),
};
println!("got message: {:?}", msg);
match msg {
FeStartupPacket::GssEncRequest => {
self.pgb.write_message(&Be::EncryptionResponse(false))?;
}
FeStartupPacket::SslRequest => {
self.pgb.write_message(&Be::EncryptionResponse(have_tls))?;
if have_tls {
self.pgb.start_tls()?;
encrypted = true;
}
}
FeStartupPacket::StartupMessage { mut params, .. } => {
if have_tls && !encrypted {
bail!("must connect with TLS");
}
let mut get_param = |key| {
params
.remove(key)
.with_context(|| format!("{} is missing in startup packet", key))
};
return Ok(Some((get_param("user")?, get_param("database")?)));
}
FeStartupPacket::CancelRequest(cancel_key_data) => {
if let Some(cancel_closure) = CANCEL_MAP.lock().unwrap().get(&cancel_key_data) {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
runtime.block_on(cancel_closure.try_cancel_query());
}
return Ok(None);
}
}
}
}
fn handle_existing_user(&mut self, user: &str, db: &str) -> anyhow::Result<DatabaseInfo> {
let md5_salt = rand::random::<[u8; 4]>();
// Ask password
self.pgb
.write_message(&Be::AuthenticationMD5Password(&md5_salt))?;
self.pgb.state = ProtoState::Authentication; // XXX
// Check password
let msg = match self.pgb.read_message()? {
Some(Fe::PasswordMessage(msg)) => msg,
None => bail!("connection is lost"),
bad => bail!("unexpected message type: {:?}", bad),
};
println!("got message: {:?}", msg);
let (_trailing_null, md5_response) = msg
.split_last()
.ok_or_else(|| anyhow!("unexpected password message"))?;
let cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters);
let db_info = cplane.authenticate_proxy_request(
user,
db,
md5_response,
&md5_salt,
&self.psql_session_id,
)?;
self.pgb
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
}
fn handle_new_user(&mut self) -> anyhow::Result<DatabaseInfo> {
let greeting = hello_message(&self.state.conf.redirect_uri, &self.psql_session_id);
// First, register this session
let waiter = self.state.waiters.register(self.psql_session_id.clone());
// Give user a URL to spawn a new database
self.pgb
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&Be::NoticeResponse(greeting))?;
// Wait for web console response
let db_info = waiter.wait()?.map_err(|e| anyhow!(e))?;
self.pgb
.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
Ok(db_info)
}
}
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
format!(
concat![
"☀️ Welcome to Zenith!\n",
"To proceed with database creation, open the following link:\n\n",
" {redirect_uri}{session_id}\n\n",
"It needs to be done once and we will send you '.pgpass' file,\n",
"which will allow you to access or create ",
"databases without opening your web browser."
],
redirect_uri = redirect_uri,
session_id = session_id,
)
}
/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message
async fn connect_to_db(
db_info: DatabaseInfo,
) -> anyhow::Result<(String, tokio::net::TcpStream, CancelKeyData)> {
// Make raw connection. When connect_raw finishes we've received ReadyForQuery.
let socket_addr = db_info.socket_addr()?;
let mut socket = tokio::net::TcpStream::connect(socket_addr).await?;
let config = tokio_postgres::Config::from(db_info);
// NOTE We effectively ignore some ParameterStatus and NoticeResponse
// messages here. Not sure if that could break something.
let (client, conn) = config.connect_raw(&mut socket, NoTls).await?;
// Save info for potentially cancelling the query later
let mut rng = StdRng::from_entropy();
let cancel_key_data = CancelKeyData {
// HACK We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
backend_pid: rng.gen(),
cancel_key: rng.gen(),
};
let cancel_closure = CancelClosure {
socket_addr,
cancel_token: client.cancel_token(),
};
CANCEL_MAP
.lock()
.unwrap()
.insert(cancel_key_data, cancel_closure);
THREAD_CANCEL_KEY_DATA.with(|cell| {
let prev_value = cell.replace(Some(cancel_key_data));
assert!(
prev_value.is_none(),
"THREAD_CANCEL_KEY_DATA was already set"
);
});
let version = conn.parameter("server_version").unwrap();
Ok((version.into(), socket, cancel_key_data))
}
/// Concurrently proxy both directions of the client and server connections
fn proxy(
(client_read, client_write): (ReadStream, WriteStream),
(server_read, server_write): (ReadStream, WriteStream),
) -> anyhow::Result<()> {
fn do_proxy(mut reader: impl io::Read, mut writer: WriteStream) -> io::Result<u64> {
/// FlushWriter will make sure that every message is sent as soon as possible
struct FlushWriter<W>(W);
impl<W: io::Write> io::Write for FlushWriter<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
// `std::io::copy` is guaranteed to exit if we return an error,
// so we can afford to lose `res` in case `flush` fails
let res = self.0.write(buf);
if let Ok(count) = res {
NUM_BYTES_PROXIED_COUNTER.inc_by(count as u64);
self.flush()?;
}
res
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
let res = std::io::copy(&mut reader, &mut FlushWriter(&mut writer));
writer.shutdown(std::net::Shutdown::Both)?;
res
}
let client_to_server_jh = thread::spawn(move || do_proxy(client_read, server_write));
do_proxy(server_read, client_write)?;
client_to_server_jh.join().unwrap()?;
Ok(())
}
/// Connect to a corresponding compute node.
async fn connect_to_db(
db_info: DatabaseInfo,
) -> anyhow::Result<(TcpStream, String, CancelClosure)> {
// TODO: establish a secure connection to the DB
let socket_addr = db_info.socket_addr()?;
let mut socket = TcpStream::connect(socket_addr).await?;
let (client, conn) = tokio_postgres::Config::from(db_info)
.connect_raw(&mut socket, NoTls)
.await?;
let version = conn
.parameter("server_version")
.context("failed to fetch postgres server version")?
.into();
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
Ok((socket, version, cancel_closure))
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::DuplexStream;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres_rustls::MakeRustlsConnect;
async fn dummy_proxy(
client: impl AsyncRead + AsyncWrite + Unpin,
tls: Option<TlsConfig>,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
// TODO: add some infra + tests for credentials
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
.await?
.context("no stream")?;
stream
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&BeMessage::ReadyForQuery)
.await?;
Ok(())
}
fn generate_certs(
hostname: &str,
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
let ca = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::default();
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
params
})?;
let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?;
Ok((
rustls::Certificate(ca.serialize_der()?),
rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
rustls::PrivateKey(cert.serialize_private_key_der()),
))
}
#[tokio::test]
async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let server_config = {
let (_ca, cert, key) = generate_certs("localhost")?;
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(vec![cert], key)?;
config
};
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Disable)
.connect_raw(server, NoTls)
.await
.err() // -> Option<E>
.context("client shouldn't be able to connect")?;
proxy
.await?
.err() // -> Option<E>
.context("server shouldn't accept client")?;
Ok(())
}
#[tokio::test]
async fn handshake_tls() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (ca, cert, key) = generate_certs("localhost")?;
let server_config = {
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(vec![cert], key)?;
config
};
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
let client_config = {
let mut config = rustls::ClientConfig::new();
config.root_store.add(&ca)?;
config
};
let mut mk = MakeRustlsConnect::new(client_config);
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, "localhost")?;
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Require)
.connect_raw(server, tls)
.await?;
proxy.await?
}
#[tokio::test]
async fn handshake_raw() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let proxy = tokio::spawn(dummy_proxy(client, None));
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Prefer)
.connect_raw(server, NoTls)
.await?;
proxy.await?
}
}

View File

@@ -1,46 +1,15 @@
use crate::cplane_api::DatabaseInfo;
use anyhow::{anyhow, ensure, Context};
use rustls::{internal::pemfile, NoClientAuth, ProtocolVersion, ServerConfig};
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
pub type TlsConfig = Arc<ServerConfig>;
#[non_exhaustive]
pub enum ClientAuthMethod {
Password,
Link,
/// Use password auth only if username ends with "@zenith"
Mixed,
}
pub enum RouterConfig {
Static { host: String, port: u16 },
Dynamic(ClientAuthMethod),
}
impl FromStr for ClientAuthMethod {
type Err = anyhow::Error;
fn from_str(s: &str) -> anyhow::Result<Self> {
use ClientAuthMethod::*;
match s {
"password" => Ok(Password),
"link" => Ok(Link),
"mixed" => Ok(Mixed),
_ => Err(anyhow::anyhow!("Invlid option for router")),
}
}
}
pub type SslConfig = Arc<ServerConfig>;
pub struct ProxyConfig {
/// main entrypoint for users to connect to
pub proxy_address: SocketAddr,
/// method of assigning compute nodes
pub router_config: RouterConfig,
/// internally used for status and prometheus metrics
pub http_address: SocketAddr,
@@ -55,10 +24,26 @@ pub struct ProxyConfig {
/// control plane address where we would check auth.
pub auth_endpoint: String,
pub tls_config: Option<TlsConfig>,
pub ssl_config: Option<SslConfig>,
}
pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfig> {
pub type ProxyWaiters = crate::waiters::Waiters<Result<DatabaseInfo, String>>;
pub struct ProxyState {
pub conf: ProxyConfig,
pub waiters: ProxyWaiters,
}
impl ProxyState {
pub fn new(conf: ProxyConfig) -> Self {
Self {
conf,
waiters: ProxyWaiters::default(),
}
}
}
pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result<SslConfig> {
let key = {
let key_bytes = std::fs::read(key_path).context("SSL key file")?;
let mut keys = pemfile::pkcs8_private_keys(&mut &key_bytes[..])

View File

@@ -1,230 +0,0 @@
use anyhow::Context;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use rustls::ServerConfig;
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_rustls::server::TlsStream;
use zenith_utils::pq_proto::{BeMessage, FeMessage, FeStartupPacket};
pin_project! {
/// Stream wrapper which implements libpq's protocol.
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
#[pin]
stream: S,
buffer: BytesMut,
}
}
impl<S> PqStream<S> {
/// Construct a new libpq protocol wrapper.
pub fn new(stream: S) -> Self {
Self {
stream,
buffer: Default::default(),
}
}
/// Extract the underlying stream.
pub fn into_inner(self) -> S {
self.stream
}
/// Get a reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.stream
}
}
impl<S: AsyncRead + Unpin> PqStream<S> {
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
pub async fn read_startup_packet(&mut self) -> anyhow::Result<FeStartupPacket> {
match FeStartupPacket::read_fut(&mut self.stream).await? {
Some(FeMessage::StartupPacket(packet)) => Ok(packet),
None => anyhow::bail!("connection is lost"),
other => anyhow::bail!("bad message type: {:?}", other),
}
}
pub async fn read_message(&mut self) -> anyhow::Result<FeMessage> {
FeMessage::read_fut(&mut self.stream)
.await?
.context("connection is lost")
}
}
impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Write the message into an internal buffer, but don't flush the underlying stream.
pub fn write_message_noflush<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buffer, message)?;
Ok(self)
}
/// Write the message into an internal buffer and flush it.
pub async fn write_message<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush().await?;
Ok(self)
}
/// Flush the output buffer into the underlying stream.
pub async fn flush(&mut self) -> io::Result<&mut Self> {
self.stream.write_all(&self.buffer).await?;
self.buffer.clear();
self.stream.flush().await?;
Ok(self)
}
}
pin_project! {
/// Wrapper for upgrading raw streams into secure streams.
/// NOTE: it should be possible to decompose this object as necessary.
#[project = StreamProj]
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { #[pin] raw: S },
/// We box [`TlsStream`] since it can be quite large.
Tls { #[pin] tls: Box<TlsStream<S>> },
}
}
impl<S> Stream<S> {
/// Construct a new instance from a raw stream.
pub fn from_raw(raw: S) -> Self {
Self::Raw { raw }
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> anyhow::Result<Self> {
match self {
Stream::Raw { raw } => {
let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
Ok(Stream::Tls { tls })
}
Stream::Tls { .. } => anyhow::bail!("can't upgrade TLS stream"),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
fn poll_read(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_read(context, buf),
Tls { tls } => tls.poll_read(context, buf),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
fn poll_write(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_write(context, buf),
Tls { tls } => tls.poll_write(context, buf),
}
}
fn poll_flush(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_flush(context),
Tls { tls } => tls.poll_flush(context),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_shutdown(context),
Tls { tls } => tls.poll_shutdown(context),
}
}
}
pin_project! {
/// This stream tracks all writes and calls user provided
/// callback when the underlying stream is flushed.
pub struct MetricsStream<S, W> {
#[pin]
stream: S,
write_count: usize,
inc_write_count: W,
}
}
impl<S, W> MetricsStream<S, W> {
pub fn new(stream: S, inc_write_count: W) -> Self {
Self {
stream,
write_count: 0,
inc_write_count,
}
}
}
impl<S: AsyncRead + Unpin, W> AsyncRead for MetricsStream<S, W> {
fn poll_read(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
self.project().stream.poll_read(context, buf)
}
}
impl<S: AsyncWrite + Unpin, W: FnMut(usize)> AsyncWrite for MetricsStream<S, W> {
fn poll_write(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
let this = self.project();
this.stream.poll_write(context, buf).map_ok(|cnt| {
// Increment the write count.
*this.write_count += cnt;
cnt
})
}
fn poll_flush(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.project();
this.stream.poll_flush(context).map_ok(|()| {
// Call the user provided callback and reset the write count.
(this.inc_write_count)(*this.write_count);
*this.write_count = 0;
})
}
fn poll_shutdown(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
self.project().stream.poll_shutdown(context)
}
}

View File

@@ -1,12 +1,8 @@
use anyhow::{anyhow, Context};
use hashbrown::HashMap;
use parking_lot::Mutex;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task;
use tokio::sync::oneshot;
use anyhow::Context;
use std::collections::HashMap;
use std::sync::{mpsc, Mutex};
pub struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
pub struct Waiters<T>(pub(self) Mutex<HashMap<String, mpsc::Sender<T>>>);
impl<T> Default for Waiters<T> {
fn default() -> Self {
@@ -15,86 +11,48 @@ impl<T> Default for Waiters<T> {
}
impl<T> Waiters<T> {
pub fn register(&self, key: String) -> anyhow::Result<Waiter<T>> {
let (tx, rx) = oneshot::channel();
pub fn register(&self, key: String) -> Waiter<T> {
let (tx, rx) = mpsc::channel();
self.0
.lock()
.try_insert(key.clone(), tx)
.map_err(|_| anyhow!("waiter already registered"))?;
// TODO: use `try_insert` (unstable)
let prev = self.0.lock().unwrap().insert(key.clone(), tx);
assert!(matches!(prev, None)); // assert_matches! is nightly-only
Ok(Waiter {
Waiter {
receiver: rx,
guard: DropKey {
registry: self,
key,
},
})
registry: self,
key,
}
}
pub fn notify(&self, key: &str, value: T) -> anyhow::Result<()>
where
T: Send + Sync,
T: Send + Sync + 'static,
{
let tx = self
.0
.lock()
.unwrap()
.remove(key)
.with_context(|| format!("key {} not found", key))?;
tx.send(value).map_err(|_| anyhow!("waiter channel hangup"))
tx.send(value).context("channel hangup")
}
}
struct DropKey<'a, T> {
key: String,
pub struct Waiter<'a, T> {
receiver: mpsc::Receiver<T>,
registry: &'a Waiters<T>,
key: String,
}
impl<'a, T> Drop for DropKey<'a, T> {
impl<T> Waiter<'_, T> {
pub fn wait(self) -> anyhow::Result<T> {
self.receiver.recv().context("channel hangup")
}
}
impl<T> Drop for Waiter<'_, T> {
fn drop(&mut self) {
self.registry.0.lock().remove(&self.key);
}
}
pin_project! {
pub struct Waiter<'a, T> {
#[pin]
receiver: oneshot::Receiver<T>,
guard: DropKey<'a, T>,
}
}
impl<T> std::future::Future for Waiter<'_, T> {
type Output = anyhow::Result<T>;
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
self.project()
.receiver
.poll(cx)
.map_err(|_| anyhow!("channel hangup"))
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
#[tokio::test]
async fn test_waiter() -> anyhow::Result<()> {
let waiters = Arc::new(Waiters::default());
let key = "Key";
let waiter = waiters.register(key.to_owned())?;
let waiters = Arc::clone(&waiters);
let notifier = tokio::spawn(async move {
waiters.notify(key, Default::default())?;
Ok(())
});
let () = waiter.await?;
notifier.await?
self.registry.0.lock().unwrap().remove(&self.key);
}
}

View File

@@ -21,7 +21,6 @@ types-psycopg2 = "^2.9.6"
boto3 = "^1.20.40"
boto3-stubs = "^1.20.40"
moto = {version = "^3.0.0", extras = ["server"]}
backoff = "^1.11.1"
[tool.poetry.dev-dependencies]
yapf = "==0.31.0"

View File

@@ -89,7 +89,7 @@ def test_foobar(zenith_env_builder: ZenithEnvBuilder):
# Now create the environment. This initializes the repository, and starts
# up the page server and the safekeepers
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
# Run the test
...

View File

@@ -1,49 +1,45 @@
from contextlib import closing
from typing import Iterator
from uuid import UUID, uuid4
from uuid import uuid4
import psycopg2
from fixtures.zenith_fixtures import ZenithEnvBuilder, ZenithPageserverApiException
from fixtures.zenith_fixtures import ZenithEnvBuilder
import pytest
pytest_plugins = ("fixtures.zenith_fixtures")
def test_pageserver_auth(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.pageserver_auth_enabled = True
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
ps = env.pageserver
tenant_token = env.auth_keys.generate_tenant_token(env.initial_tenant.hex)
tenant_http_client = env.pageserver.http_client(tenant_token)
tenant_token = env.auth_keys.generate_tenant_token(env.initial_tenant)
invalid_tenant_token = env.auth_keys.generate_tenant_token(uuid4().hex)
invalid_tenant_http_client = env.pageserver.http_client(invalid_tenant_token)
management_token = env.auth_keys.generate_management_token()
management_http_client = env.pageserver.http_client(management_token)
# this does not invoke auth check and only decodes jwt and checks it for validity
# check both tokens
ps.safe_psql("set FOO", password=tenant_token)
ps.safe_psql("set FOO", password=management_token)
ps.safe_psql("status", password=tenant_token)
ps.safe_psql("status", password=management_token)
# tenant can create branches
tenant_http_client.branch_create(env.initial_tenant, 'new1', 'main')
ps.safe_psql(f"branch_create {env.initial_tenant} new1 main", password=tenant_token)
# console can create branches for tenant
management_http_client.branch_create(env.initial_tenant, 'new2', 'main')
ps.safe_psql(f"branch_create {env.initial_tenant} new2 main", password=management_token)
# fail to create branch using token with different tenant_id
with pytest.raises(ZenithPageserverApiException,
match='Forbidden: Tenant id mismatch. Permission denied'):
invalid_tenant_http_client.branch_create(env.initial_tenant, "new3", "main")
# fail to create branch using token with different tenantid
with pytest.raises(psycopg2.DatabaseError, match='Tenant id mismatch. Permission denied'):
ps.safe_psql(f"branch_create {env.initial_tenant} new2 main", password=invalid_tenant_token)
# create tenant using management token
management_http_client.tenant_create(uuid4())
ps.safe_psql(f"tenant_create {uuid4().hex}", password=management_token)
# fail to create tenant using tenant token
with pytest.raises(
ZenithPageserverApiException,
match='Forbidden: Attempt to access management api with tenant scope. Permission denied'
):
tenant_http_client.tenant_create(uuid4())
psycopg2.DatabaseError,
match='Attempt to access management api with tenant scope. Permission denied'):
ps.safe_psql(f"tenant_create {uuid4().hex}", password=tenant_token)
@pytest.mark.parametrize('with_wal_acceptors', [False, True])
@@ -51,10 +47,10 @@ def test_compute_auth_to_pageserver(zenith_env_builder: ZenithEnvBuilder, with_w
zenith_env_builder.pageserver_auth_enabled = True
if with_wal_acceptors:
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
branch = f"test_compute_auth_to_pageserver{with_wal_acceptors}"
env.zenith_cli.create_branch(branch, "main")
env.zenith_cli(["branch", branch, "main"])
pg = env.postgres.create_start(branch)

View File

@@ -1,154 +0,0 @@
from contextlib import closing, contextmanager
import psycopg2.extras
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
import os
import time
import asyncpg
from fixtures.zenith_fixtures import Postgres
import threading
pytest_plugins = ("fixtures.zenith_fixtures")
@contextmanager
def pg_cur(pg):
with closing(pg.connect()) as conn:
with conn.cursor() as cur:
yield cur
# Periodically check that all backpressure lags are below the configured threshold,
# assert if they are not.
# If the check query fails, stop the thread. Main thread should notice that and stop the test.
def check_backpressure(pg: Postgres, stop_event: threading.Event, polling_interval=5):
log.info("checks started")
with pg_cur(pg) as cur:
cur.execute("CREATE EXTENSION zenith") # TODO move it to zenith_fixtures?
cur.execute("select pg_size_bytes(current_setting('max_replication_write_lag'))")
res = cur.fetchone()
max_replication_write_lag_bytes = res[0]
log.info(f"max_replication_write_lag: {max_replication_write_lag_bytes} bytes")
cur.execute("select pg_size_bytes(current_setting('max_replication_flush_lag'))")
res = cur.fetchone()
max_replication_flush_lag_bytes = res[0]
log.info(f"max_replication_flush_lag: {max_replication_flush_lag_bytes} bytes")
cur.execute("select pg_size_bytes(current_setting('max_replication_apply_lag'))")
res = cur.fetchone()
max_replication_apply_lag_bytes = res[0]
log.info(f"max_replication_apply_lag: {max_replication_apply_lag_bytes} bytes")
with pg_cur(pg) as cur:
while not stop_event.is_set():
try:
cur.execute('''
select pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn) as received_lsn_lag,
pg_wal_lsn_diff(pg_current_wal_flush_lsn(),disk_consistent_lsn) as disk_consistent_lsn_lag,
pg_wal_lsn_diff(pg_current_wal_flush_lsn(),remote_consistent_lsn) as remote_consistent_lsn_lag,
pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn)),
pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),disk_consistent_lsn)),
pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),remote_consistent_lsn))
from backpressure_lsns();
''')
res = cur.fetchone()
received_lsn_lag = res[0]
disk_consistent_lsn_lag = res[1]
remote_consistent_lsn_lag = res[2]
log.info(f"received_lsn_lag = {received_lsn_lag} ({res[3]}), "
f"disk_consistent_lsn_lag = {disk_consistent_lsn_lag} ({res[4]}), "
f"remote_consistent_lsn_lag = {remote_consistent_lsn_lag} ({res[5]})")
# Since feedback from pageserver is not immediate, we should allow some lag overflow
lag_overflow = 5 * 1024 * 1024 # 5MB
if max_replication_write_lag_bytes > 0:
assert received_lsn_lag < max_replication_write_lag_bytes + lag_overflow
if max_replication_flush_lag_bytes > 0:
assert disk_consistent_lsn_lag < max_replication_flush_lag_bytes + lag_overflow
if max_replication_apply_lag_bytes > 0:
assert remote_consistent_lsn_lag < max_replication_apply_lag_bytes + lag_overflow
time.sleep(polling_interval)
except Exception as e:
log.info(f"backpressure check query failed: {e}")
stop_event.set()
log.info('check thread stopped')
# This test illustrates how to tune backpressure to control the lag
# between the WAL flushed on compute node and WAL digested by pageserver.
#
# To test it, throttle walreceiver ingest using failpoint and run heavy write load.
# If backpressure is disabled or not tuned properly, the query will timeout, because the walreceiver cannot keep up.
# If backpressure is enabled and tuned properly, insertion will be throttled, but the query will not timeout.
def test_backpressure_received_lsn_lag(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init_start()
# Create a branch for us
env.zenith_cli.create_branch("test_backpressure", "main")
pg = env.postgres.create_start('test_backpressure',
config_lines=['max_replication_write_lag=30MB'])
log.info("postgres is running on 'test_backpressure' branch")
# setup check thread
check_stop_event = threading.Event()
check_thread = threading.Thread(target=check_backpressure, args=(pg, check_stop_event))
check_thread.start()
# Configure failpoint to slow down walreceiver ingest
with closing(env.pageserver.connect()) as psconn:
with psconn.cursor(cursor_factory=psycopg2.extras.DictCursor) as pscur:
pscur.execute("failpoints walreceiver-after-ingest=sleep(20)")
# FIXME
# Wait for the check thread to start
#
# Now if load starts too soon,
# check thread cannot auth, because it is not able to connect to the database
# because of the lag and waiting for lsn to replay to arrive.
time.sleep(2)
with pg_cur(pg) as cur:
# Create and initialize test table
cur.execute("CREATE TABLE foo(x bigint)")
inserts_to_do = 2000000
rows_inserted = 0
while check_thread.is_alive() and rows_inserted < inserts_to_do:
try:
cur.execute("INSERT INTO foo select from generate_series(1, 100000)")
rows_inserted += 100000
except Exception as e:
if check_thread.is_alive():
log.info('stopping check thread')
check_stop_event.set()
check_thread.join()
assert False, f"Exception {e} while inserting rows, but WAL lag is within configured threshold. That means backpressure is not tuned properly"
else:
assert False, f"Exception {e} while inserting rows and WAL lag overflowed configured threshold. That means backpressure doesn't work."
log.info(f"inserted {rows_inserted} rows")
if check_thread.is_alive():
log.info('stopping check thread')
check_stop_event.set()
check_thread.join()
log.info('check thread stopped')
else:
assert False, "WAL lag overflowed configured threshold. That means backpressure doesn't work."
#TODO test_backpressure_disk_consistent_lsn_lag. Play with pageserver's checkpoint settings
#TODO test_backpressure_remote_consistent_lsn_lag

View File

@@ -7,6 +7,8 @@ from fixtures.log_helper import log
from fixtures.utils import print_gc_result
from fixtures.zenith_fixtures import ZenithEnvBuilder
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Create a couple of branches off the main branch, at a historical point in time.
@@ -19,10 +21,10 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
#
# See https://github.com/zenithdb/zenith/issues/1068
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
# Branch at the point where only 100 rows were inserted
env.zenith_cli.create_branch("test_branch_behind", "main")
env.zenith_cli(["branch", "test_branch_behind", "main"])
pgmain = env.postgres.create_start('test_branch_behind')
log.info("postgres is running on 'test_branch_behind' branch")
@@ -60,7 +62,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
log.info(f'LSN after 200100 rows: {lsn_b}')
# Branch at the point where only 100 rows were inserted
env.zenith_cli.create_branch("test_branch_behind_hundred", "test_branch_behind@" + lsn_a)
env.zenith_cli(["branch", "test_branch_behind_hundred", "test_branch_behind@" + lsn_a])
# Insert many more rows. This generates enough WAL to fill a few segments.
main_cur.execute('''
@@ -75,7 +77,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
log.info(f'LSN after 400100 rows: {lsn_c}')
# Branch at the point where only 200100 rows were inserted
env.zenith_cli.create_branch("test_branch_behind_more", "test_branch_behind@" + lsn_b)
env.zenith_cli(["branch", "test_branch_behind_more", "test_branch_behind@" + lsn_b])
pg_hundred = env.postgres.create_start("test_branch_behind_hundred")
pg_more = env.postgres.create_start("test_branch_behind_more")
@@ -99,7 +101,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
# Check bad lsn's for branching
# branch at segment boundary
env.zenith_cli.create_branch("test_branch_segment_boundary", "test_branch_behind@0/3000000")
env.zenith_cli(["branch", "test_branch_segment_boundary", "test_branch_behind@0/3000000"])
pg = env.postgres.create_start("test_branch_segment_boundary")
cur = pg.connect().cursor()
cur.execute('SELECT 1')
@@ -107,23 +109,23 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
# branch at pre-initdb lsn
with pytest.raises(Exception, match="invalid branch start lsn"):
env.zenith_cli.create_branch("test_branch_preinitdb", "main@0/42")
env.zenith_cli(["branch", "test_branch_preinitdb", "main@0/42"])
# branch at pre-ancestor lsn
with pytest.raises(Exception, match="less than timeline ancestor lsn"):
env.zenith_cli.create_branch("test_branch_preinitdb", "test_branch_behind@0/42")
env.zenith_cli(["branch", "test_branch_preinitdb", "test_branch_behind@0/42"])
# check that we cannot create branch based on garbage collected data
with closing(env.pageserver.connect()) as psconn:
with psconn.cursor(cursor_factory=psycopg2.extras.DictCursor) as pscur:
# call gc to advace latest_gc_cutoff_lsn
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
with pytest.raises(Exception, match="invalid branch start lsn"):
# this gced_lsn is pretty random, so if gc is disabled this woudln't fail
env.zenith_cli.create_branch("test_branch_create_fail", f"test_branch_behind@{gced_lsn}")
env.zenith_cli(["branch", "test_branch_create_fail", f"test_branch_behind@{gced_lsn}"])
# check that after gc everything is still there
hundred_cur.execute('SELECT count(*) FROM foo')

View File

@@ -6,13 +6,16 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test compute node start after clog truncation
#
def test_clog_truncate(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_clog_truncate", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_clog_truncate", "empty"])
# set agressive autovacuum to make sure that truncation will happen
config = [
@@ -62,8 +65,8 @@ def test_clog_truncate(zenith_simple_env: ZenithEnv):
# create new branch after clog truncation and start a compute node on it
log.info(f'create branch at lsn_after_truncation {lsn_after_truncation}')
env.zenith_cli.create_branch("test_clog_truncate_new",
"test_clog_truncate@" + lsn_after_truncation)
env.zenith_cli(
["branch", "test_clog_truncate_new", "test_clog_truncate@" + lsn_after_truncation])
pg2 = env.postgres.create_start('test_clog_truncate_new')
log.info('postgres is running on test_clog_truncate_new branch')

View File

@@ -3,13 +3,16 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test starting Postgres with custom options
#
def test_config(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_config", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_config", "empty"])
# change config
pg = env.postgres.create_start('test_config', config_lines=['log_min_messages=debug1'])

View File

@@ -5,13 +5,15 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test CREATE DATABASE when there have been relmapper changes
#
def test_createdb(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_createdb", "empty")
env.zenith_cli(["branch", "test_createdb", "empty"])
pg = env.postgres.create_start('test_createdb')
log.info("postgres is running on 'test_createdb' branch")
@@ -27,7 +29,7 @@ def test_createdb(zenith_simple_env: ZenithEnv):
lsn = cur.fetchone()[0]
# Create a branch
env.zenith_cli.create_branch("test_createdb2", "test_createdb@" + lsn)
env.zenith_cli(["branch", "test_createdb2", "test_createdb@" + lsn])
pg2 = env.postgres.create_start('test_createdb2')
@@ -41,7 +43,7 @@ def test_createdb(zenith_simple_env: ZenithEnv):
#
def test_dropdb(zenith_simple_env: ZenithEnv, test_output_dir):
env = zenith_simple_env
env.zenith_cli.create_branch("test_dropdb", "empty")
env.zenith_cli(["branch", "test_dropdb", "empty"])
pg = env.postgres.create_start('test_dropdb')
log.info("postgres is running on 'test_dropdb' branch")
@@ -66,10 +68,10 @@ def test_dropdb(zenith_simple_env: ZenithEnv, test_output_dir):
lsn_after_drop = cur.fetchone()[0]
# Create two branches before and after database drop.
env.zenith_cli.create_branch("test_before_dropdb", "test_dropdb@" + lsn_before_drop)
env.zenith_cli(["branch", "test_before_dropdb", "test_dropdb@" + lsn_before_drop])
pg_before = env.postgres.create_start('test_before_dropdb')
env.zenith_cli.create_branch("test_after_dropdb", "test_dropdb@" + lsn_after_drop)
env.zenith_cli(["branch", "test_after_dropdb", "test_dropdb@" + lsn_after_drop])
pg_after = env.postgres.create_start('test_after_dropdb')
# Test that database exists on the branch before drop

View File

@@ -3,13 +3,15 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test CREATE USER to check shared catalog restore
#
def test_createuser(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_createuser", "empty")
env.zenith_cli(["branch", "test_createuser", "empty"])
pg = env.postgres.create_start('test_createuser')
log.info("postgres is running on 'test_createuser' branch")
@@ -25,7 +27,7 @@ def test_createuser(zenith_simple_env: ZenithEnv):
lsn = cur.fetchone()[0]
# Create a branch
env.zenith_cli.create_branch("test_createuser2", "test_createuser@" + lsn)
env.zenith_cli(["branch", "test_createuser2", "test_createuser@" + lsn])
pg2 = env.postgres.create_start('test_createuser2')

View File

@@ -7,6 +7,8 @@ import random
from fixtures.zenith_fixtures import ZenithEnv, Postgres, Safekeeper
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
# Test configuration
#
# Create a table with {num_rows} rows, and perform {updates_to_perform} random
@@ -34,7 +36,7 @@ async def gc(env: ZenithEnv, timeline: str):
psconn = await env.pageserver.connect_async()
while updates_performed < updates_to_perform:
await psconn.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
await psconn.execute(f"do_gc {env.initial_tenant} {timeline} 0")
# At the same time, run UPDATEs and GC
@@ -55,7 +57,9 @@ async def update_and_gc(env: ZenithEnv, pg: Postgres, timeline: str):
#
def test_gc_aggressive(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_gc_aggressive", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_gc_aggressive", "empty"])
pg = env.postgres.create_start('test_gc_aggressive')
log.info('postgres is running on test_gc_aggressive branch')

View File

@@ -1,6 +1,8 @@
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test multixact state after branching
@@ -10,7 +12,8 @@ from fixtures.log_helper import log
#
def test_multixact(zenith_simple_env: ZenithEnv, test_output_dir):
env = zenith_simple_env
env.zenith_cli.create_branch("test_multixact", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_multixact", "empty"])
pg = env.postgres.create_start('test_multixact')
log.info("postgres is running on 'test_multixact' branch")
@@ -60,7 +63,7 @@ def test_multixact(zenith_simple_env: ZenithEnv, test_output_dir):
assert int(next_multixact_id) > int(next_multixact_id_old)
# Branch at this point
env.zenith_cli.create_branch("test_multixact_new", "test_multixact@" + lsn)
env.zenith_cli(["branch", "test_multixact_new", "test_multixact@" + lsn])
pg_new = env.postgres.create_start('test_multixact_new')
log.info("postgres is running on 'test_multixact_new' branch")

View File

@@ -5,13 +5,15 @@ import time
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
# Test restarting page server, while safekeeper and compute node keep
# running.
def test_next_xid(zenith_env_builder: ZenithEnvBuilder):
# One safekeeper is enough for this test.
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
pg = env.postgres.create_start('main')

View File

@@ -3,6 +3,8 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test where Postgres generates a lot of WAL, and it's garbage collected away, but
@@ -16,7 +18,8 @@ from fixtures.log_helper import log
#
def test_old_request_lsn(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_old_request_lsn", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_old_request_lsn", "empty"])
pg = env.postgres.create_start('test_old_request_lsn')
log.info('postgres is running on test_old_request_lsn branch')
@@ -54,7 +57,7 @@ def test_old_request_lsn(zenith_simple_env: ZenithEnv):
# Make a lot of updates on a single row, generating a lot of WAL. Trigger
# garbage collections so that the page server will remove old page versions.
for i in range(10):
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
for j in range(100):
cur.execute('UPDATE foo SET val = val + 1 WHERE id = 1;')

View File

@@ -1,22 +1,95 @@
import json
from uuid import uuid4, UUID
import pytest
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder, ZenithPageserverHttpClient, zenith_binpath
import psycopg2
import requests
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder, ZenithPageserverHttpClient
from typing import cast
pytest_plugins = ("fixtures.zenith_fixtures")
# test that we cannot override node id
def test_pageserver_init_node_id(zenith_env_builder: ZenithEnvBuilder):
def test_status_psql(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
assert env.pageserver.safe_psql('status') == [
('hello world', ),
]
def test_branch_list_psql(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
# Create a branch for us
env.zenith_cli(["branch", "test_branch_list_main", "empty"])
conn = env.pageserver.connect()
cur = conn.cursor()
cur.execute(f'branch_list {env.initial_tenant}')
branches = json.loads(cur.fetchone()[0])
# Filter out branches created by other tests
branches = [x for x in branches if x['name'].startswith('test_branch_list')]
assert len(branches) == 1
assert branches[0]['name'] == 'test_branch_list_main'
assert 'timeline_id' in branches[0]
assert 'latest_valid_lsn' in branches[0]
assert 'ancestor_id' in branches[0]
assert 'ancestor_lsn' in branches[0]
# Create another branch, and start Postgres on it
env.zenith_cli(['branch', 'test_branch_list_experimental', 'test_branch_list_main'])
env.zenith_cli(['pg', 'create', 'test_branch_list_experimental'])
cur.execute(f'branch_list {env.initial_tenant}')
new_branches = json.loads(cur.fetchone()[0])
# Filter out branches created by other tests
new_branches = [x for x in new_branches if x['name'].startswith('test_branch_list')]
assert len(new_branches) == 2
new_branches.sort(key=lambda k: k['name'])
assert new_branches[0]['name'] == 'test_branch_list_experimental'
assert new_branches[0]['timeline_id'] != branches[0]['timeline_id']
# TODO: do the LSNs have to match here?
assert new_branches[1] == branches[0]
conn.close()
def test_tenant_list_psql(zenith_env_builder: ZenithEnvBuilder):
# don't use zenith_simple_env, because there might be other tenants there,
# left over from other tests.
env = zenith_env_builder.init()
with pytest.raises(
Exception,
match="node id can only be set during pageserver init and cannot be overridden"):
env.pageserver.start(overrides=['--pageserver-config-override=id=10'])
res = env.zenith_cli(["tenant", "list"])
res.check_returncode()
tenants = sorted(map(lambda t: t.split()[0], res.stdout.splitlines()))
assert tenants == [env.initial_tenant]
conn = env.pageserver.connect()
cur = conn.cursor()
# check same tenant cannot be created twice
with pytest.raises(psycopg2.DatabaseError,
match=f'repo for {env.initial_tenant} already exists'):
cur.execute(f'tenant_create {env.initial_tenant}')
# create one more tenant
tenant1 = uuid4().hex
cur.execute(f'tenant_create {tenant1}')
cur.execute('tenant_list')
# compare tenants list
new_tenants = sorted(map(lambda t: cast(str, t['id']), json.loads(cur.fetchone()[0])))
assert sorted([env.initial_tenant, tenant1]) == new_tenants
def check_client(client: ZenithPageserverHttpClient, initial_tenant: UUID):
def check_client(client: ZenithPageserverHttpClient, initial_tenant: str):
client.check_status()
# check initial tenant is there
assert initial_tenant.hex in {t['id'] for t in client.tenant_list()}
assert initial_tenant in {t['id'] for t in client.tenant_list()}
# create new tenant and check it is also there
tenant_id = uuid4()
@@ -48,7 +121,7 @@ def test_pageserver_http_api_client(zenith_simple_env: ZenithEnv):
def test_pageserver_http_api_client_auth_enabled(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.pageserver_auth_enabled = True
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
management_token = env.auth_keys.generate_management_token()

View File

@@ -7,6 +7,8 @@ from multiprocessing import Process, Value
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
# Test safekeeper sync and pageserver catch up
# while initial compute node is down and pageserver is lagging behind safekeepers.
@@ -14,9 +16,9 @@ from fixtures.log_helper import log
# and new compute node contains all data.
def test_pageserver_catchup_while_compute_down(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_pageserver_catchup_while_compute_down", "main")
env.zenith_cli(["branch", "test_pageserver_catchup_while_compute_down", "main"])
pg = env.postgres.create_start('test_pageserver_catchup_while_compute_down')
pg_conn = pg.connect()

View File

@@ -7,15 +7,17 @@ from multiprocessing import Process, Value
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
# Test restarting page server, while safekeeper and compute node keep
# running.
def test_pageserver_restart(zenith_env_builder: ZenithEnvBuilder):
# One safekeeper is enough for this test.
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_pageserver_restart", "main")
env.zenith_cli(["branch", "test_pageserver_restart", "main"])
pg = env.postgres.create_start('test_pageserver_restart')
pg_conn = pg.connect()

View File

@@ -5,6 +5,8 @@ import subprocess
from fixtures.zenith_fixtures import ZenithEnv, Postgres
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
async def repeat_bytes(buf, repetitions: int):
for i in range(repetitions):
@@ -37,7 +39,9 @@ async def parallel_load_same_table(pg: Postgres, n_parallel: int):
# Load data into one table with COPY TO from 5 parallel connections
def test_parallel_copy(zenith_simple_env: ZenithEnv, n_parallel=5):
env = zenith_simple_env
env.zenith_cli.create_branch("test_parallel_copy", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_parallel_copy", "empty"])
pg = env.postgres.create_start('test_parallel_copy')
log.info("postgres is running on 'test_parallel_copy' branch")

View File

@@ -1,10 +1,14 @@
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
def test_pgbench(zenith_simple_env: ZenithEnv, pg_bin):
env = zenith_simple_env
env.zenith_cli.create_branch("test_pgbench", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_pgbench", "empty"])
pg = env.postgres.create_start('test_pgbench')
log.info("postgres is running on 'test_pgbench' branch")

View File

@@ -1,15 +0,0 @@
import pytest
def test_proxy_select_1(static_proxy):
static_proxy.safe_psql("select 1;")
@pytest.mark.xfail # Proxy eats the extra connection options
def test_proxy_options(static_proxy):
schema_name = "tmp_schema_1"
with static_proxy.connect(schema=schema_name) as conn:
with conn.cursor() as cur:
cur.execute("SHOW search_path;")
search_path = cur.fetchall()[0][0]
assert schema_name == search_path

View File

@@ -2,6 +2,8 @@ import pytest
from fixtures.log_helper import log
from fixtures.zenith_fixtures import ZenithEnv
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Create read-only compute nodes, anchored at historical points in time.
@@ -11,7 +13,7 @@ from fixtures.zenith_fixtures import ZenithEnv
#
def test_readonly_node(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_readonly_node", "empty")
env.zenith_cli(["branch", "test_readonly_node", "empty"])
pgmain = env.postgres.create_start('test_readonly_node')
log.info("postgres is running on 'test_readonly_node' branch")
@@ -86,5 +88,4 @@ def test_readonly_node(zenith_simple_env: ZenithEnv):
# Create node at pre-initdb lsn
with pytest.raises(Exception, match="invalid basebackup lsn"):
# compute node startup with invalid LSN should fail
env.zenith_cli.pg_start("test_readonly_node_preinitdb",
timeline_spec="test_readonly_node@0/42")
env.zenith_cli(["pg", "start", "test_readonly_node_preinitdb", "test_readonly_node@0/42"])

View File

@@ -9,6 +9,8 @@ from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
import pytest
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Tests that a piece of data is backed up and restored correctly:
@@ -42,7 +44,7 @@ def test_remote_storage_backup_and_restore(zenith_env_builder: ZenithEnvBuilder,
data_secret = 'very secret secret'
##### First start, insert secret data and upload it to the remote storage
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
pg = env.postgres.create_start()
tenant_id = pg.safe_psql("show zenith.zenith_tenant")[0][0]

View File

@@ -4,6 +4,8 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test restarting and recreating a postgres instance
@@ -13,9 +15,9 @@ def test_restart_compute(zenith_env_builder: ZenithEnvBuilder, with_wal_acceptor
zenith_env_builder.pageserver_auth_enabled = True
if with_wal_acceptors:
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_restart_compute", "main")
env.zenith_cli(["branch", "test_restart_compute", "main"])
pg = env.postgres.create_start('test_restart_compute')
log.info("postgres is running on 'test_restart_compute' branch")

View File

@@ -5,6 +5,8 @@ from fixtures.utils import print_gc_result
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test Garbage Collection of old layer files
@@ -14,7 +16,7 @@ from fixtures.log_helper import log
#
def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_layerfiles_gc", "empty")
env.zenith_cli(["branch", "test_layerfiles_gc", "empty"])
pg = env.postgres.create_start('test_layerfiles_gc')
with closing(pg.connect()) as conn:
@@ -48,7 +50,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
cur.execute("DELETE FROM foo")
log.info("Running GC before test")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
# remember the number of files
@@ -61,7 +63,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
# removing the old image and delta layer.
log.info("Inserting one row and running GC")
cur.execute("INSERT INTO foo VALUES (1)")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
assert row['layer_relfiles_total'] == layer_relfiles_remain + 2
@@ -75,7 +77,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
cur.execute("INSERT INTO foo VALUES (2)")
cur.execute("INSERT INTO foo VALUES (3)")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
assert row['layer_relfiles_total'] == layer_relfiles_remain + 2
@@ -87,7 +89,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
cur.execute("INSERT INTO foo VALUES (2)")
cur.execute("INSERT INTO foo VALUES (3)")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
assert row['layer_relfiles_total'] == layer_relfiles_remain + 2
@@ -96,7 +98,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
# Run GC again, with no changes in the database. Should not remove anything.
log.info("Run GC again, with nothing to do")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
assert row['layer_relfiles_total'] == layer_relfiles_remain
@@ -109,7 +111,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
log.info("Drop table and run GC again")
cur.execute("DROP TABLE foo")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)

View File

@@ -1,6 +1,8 @@
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
# Test subtransactions
#
@@ -10,7 +12,8 @@ from fixtures.log_helper import log
# CLOG.
def test_subxacts(zenith_simple_env: ZenithEnv, test_output_dir):
env = zenith_simple_env
env.zenith_cli.create_branch("test_subxacts", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_subxacts", "empty"])
pg = env.postgres.create_start('test_subxacts')
log.info("postgres is running on 'test_subxacts' branch")

View File

@@ -108,8 +108,8 @@ def load(pg: Postgres, stop_event: threading.Event, load_ok_event: threading.Eve
log.info('load thread stopped')
def assert_local(pageserver_http_client: ZenithPageserverHttpClient, tenant: UUID, timeline: str):
timeline_detail = pageserver_http_client.timeline_detail(tenant, UUID(timeline))
def assert_local(pageserver_http_client: ZenithPageserverHttpClient, tenant: str, timeline: str):
timeline_detail = pageserver_http_client.timeline_detail(UUID(tenant), UUID(timeline))
assert timeline_detail.get('type') == "Local", timeline_detail
return timeline_detail
@@ -122,15 +122,15 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
zenith_env_builder.num_safekeepers = 1
zenith_env_builder.enable_local_fs_remote_storage()
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
# create folder for remote storage mock
remote_storage_mock_path = env.repo_dir / 'local_fs_remote_storage'
tenant = env.create_tenant(UUID("74ee8b079a0e437eb0afea7d26a07209"))
tenant = env.create_tenant("74ee8b079a0e437eb0afea7d26a07209")
log.info("tenant to relocate %s", tenant)
env.zenith_cli.create_branch("test_tenant_relocation", "main", tenant_id=tenant)
env.zenith_cli(["branch", "test_tenant_relocation", "main", f"--tenantid={tenant}"])
tenant_pg = env.postgres.create_start(
"test_tenant_relocation",
@@ -167,11 +167,11 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
# run checkpoint manually to be sure that data landed in remote storage
with closing(env.pageserver.connect()) as psconn:
with psconn.cursor() as pscur:
pscur.execute(f"do_gc {tenant.hex} {timeline}")
pscur.execute(f"do_gc {tenant} {timeline}")
# ensure upload is completed
pageserver_http_client = env.pageserver.http_client()
timeline_detail = pageserver_http_client.timeline_detail(tenant, UUID(timeline))
timeline_detail = pageserver_http_client.timeline_detail(UUID(tenant), UUID(timeline))
assert timeline_detail['disk_consistent_lsn'] == timeline_detail['timeline_state']['Ready']
log.info("inititalizing new pageserver")
@@ -194,7 +194,7 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
new_pageserver_http_port):
# call to attach timeline to new pageserver
new_pageserver_http_client.timeline_attach(tenant, UUID(timeline))
new_pageserver_http_client.timeline_attach(UUID(tenant), UUID(timeline))
# FIXME cannot handle duplicate download requests, subject to fix in https://github.com/zenithdb/zenith/issues/997
time.sleep(5)
# new pageserver should in sync (modulo wal tail or vacuum activity) with the old one because there was no new writes since checkpoint
@@ -241,7 +241,7 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
# detach tenant from old pageserver before we check
# that all the data is there to be sure that old pageserver
# is no longer involved, and if it is, we will see the errors
pageserver_http_client.timeline_detach(tenant, UUID(timeline))
pageserver_http_client.timeline_detach(UUID(tenant), UUID(timeline))
with pg_cur(tenant_pg) as cur:
# check that data is still there

View File

@@ -10,17 +10,23 @@ def test_tenants_normal_work(zenith_env_builder: ZenithEnvBuilder, with_wal_acce
if with_wal_acceptors:
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
"""Tests tenants with and without wal acceptors"""
tenant_1 = env.create_tenant()
tenant_2 = env.create_tenant()
env.zenith_cli.create_branch(f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
"main",
tenant_id=tenant_1)
env.zenith_cli.create_branch(f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
"main",
tenant_id=tenant_2)
env.zenith_cli([
"branch",
f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
"main",
f"--tenantid={tenant_1}"
])
env.zenith_cli([
"branch",
f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
"main",
f"--tenantid={tenant_2}"
])
pg_tenant1 = env.postgres.create_start(
f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",

View File

@@ -10,10 +10,10 @@ import time
def test_timeline_size(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
# Branch at the point where only 100 rows were inserted
env.zenith_cli.create_branch("test_timeline_size", "empty")
env.zenith_cli(["branch", "test_timeline_size", "empty"])
client = env.pageserver.http_client()
res = client.branch_detail(env.initial_tenant, "test_timeline_size")
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
pgmain = env.postgres.create_start("test_timeline_size")
@@ -31,47 +31,47 @@ def test_timeline_size(zenith_simple_env: ZenithEnv):
FROM generate_series(1, 10) g
""")
res = client.branch_detail(env.initial_tenant, "test_timeline_size")
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
cur.execute("TRUNCATE foo")
res = client.branch_detail(env.initial_tenant, "test_timeline_size")
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
# wait until received_lsn_lag is 0
# wait until write_lag is 0
def wait_for_pageserver_catchup(pgmain: Postgres, polling_interval=1, timeout=60):
started_at = time.time()
received_lsn_lag = 1
while received_lsn_lag > 0:
write_lag = 1
while write_lag > 0:
elapsed = time.time() - started_at
if elapsed > timeout:
raise RuntimeError(
f"timed out waiting for pageserver to reach pg_current_wal_flush_lsn()")
raise RuntimeError(f"timed out waiting for pageserver to reach pg_current_wal_lsn()")
with closing(pgmain.connect()) as conn:
with conn.cursor() as cur:
cur.execute('''
select pg_size_pretty(pg_cluster_size()),
pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn) as received_lsn_lag
FROM backpressure_lsns();
pg_wal_lsn_diff(pg_current_wal_lsn(),write_lsn) as write_lag,
pg_wal_lsn_diff(pg_current_wal_lsn(),sent_lsn) as pending_lag
FROM pg_stat_get_wal_senders();
''')
res = cur.fetchone()
log.info(f"pg_cluster_size = {res[0]}, received_lsn_lag = {res[1]}")
received_lsn_lag = res[1]
log.info(
f"pg_cluster_size = {res[0]}, write_lag = {res[1]}, pending_lag = {res[2]}")
write_lag = res[1]
time.sleep(polling_interval)
def test_timeline_size_quota(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init_start()
env.zenith_cli.create_branch("test_timeline_size_quota", "main")
env = zenith_env_builder.init()
env.zenith_cli(["branch", "test_timeline_size_quota", "main"])
client = env.pageserver.http_client()
res = client.branch_detail(env.initial_tenant, "test_timeline_size_quota")
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size_quota")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
pgmain = env.postgres.create_start(

View File

@@ -3,13 +3,15 @@ import os
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test branching, when a transaction is in prepared state
#
def test_twophase(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_twophase", "empty")
env.zenith_cli(["branch", "test_twophase", "empty"])
pg = env.postgres.create_start('test_twophase', config_lines=['max_prepared_transactions=5'])
log.info("postgres is running on 'test_twophase' branch")
@@ -56,7 +58,7 @@ def test_twophase(zenith_simple_env: ZenithEnv):
assert len(twophase_files) == 2
# Create a branch with the transaction in prepared state
env.zenith_cli.create_branch("test_twophase_prepared", "test_twophase")
env.zenith_cli(["branch", "test_twophase_prepared", "test_twophase"])
# Start compute on the new branch
pg2 = env.postgres.create_start(
@@ -78,8 +80,8 @@ def test_twophase(zenith_simple_env: ZenithEnv):
cur2.execute("ROLLBACK PREPARED 'insert_two'")
cur2.execute('SELECT * FROM foo')
assert cur2.fetchall() == [('one', ), ('three', )]
assert cur2.fetchall() == [('one', ), ('three', )] # type: ignore[comparison-overlap]
# Only one committed insert is visible on the original branch
cur.execute('SELECT * FROM foo')
assert cur.fetchall() == [('three', )]
assert cur.fetchall() == [('three', )] # type: ignore[comparison-overlap]

View File

@@ -1,6 +1,8 @@
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test that the VM bit is cleared correctly at a HEAP_DELETE and
@@ -9,7 +11,8 @@ from fixtures.log_helper import log
def test_vm_bit_clear(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_vm_bit_clear", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_vm_bit_clear", "empty"])
pg = env.postgres.create_start('test_vm_bit_clear')
log.info("postgres is running on 'test_vm_bit_clear' branch")
@@ -33,7 +36,7 @@ def test_vm_bit_clear(zenith_simple_env: ZenithEnv):
cur.execute('UPDATE vmtest_update SET id = 5000 WHERE id = 1')
# Branch at this point, to test that later
env.zenith_cli.create_branch("test_vm_bit_clear_new", "test_vm_bit_clear")
env.zenith_cli(["branch", "test_vm_bit_clear_new", "test_vm_bit_clear"])
# Clear the buffer cache, to force the VM page to be re-fetched from
# the page server

View File

@@ -17,14 +17,16 @@ from fixtures.utils import lsn_to_hex, mkdir_if_needed
from fixtures.log_helper import log
from typing import List, Optional, Any
pytest_plugins = ("fixtures.zenith_fixtures")
# basic test, write something in setup with wal acceptors, ensure that commits
# succeed and data is written
def test_normal_work(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_wal_acceptors_normal_work", "main")
env.zenith_cli(["branch", "test_wal_acceptors_normal_work", "main"])
pg = env.postgres.create_start('test_wal_acceptors_normal_work')
@@ -51,7 +53,7 @@ class BranchMetrics:
# against different timelines.
def test_many_timelines(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
n_timelines = 3
@@ -60,10 +62,10 @@ def test_many_timelines(zenith_env_builder: ZenithEnvBuilder):
# start postgres on each timeline
pgs = []
for branch in branches:
env.zenith_cli.create_branch(branch, "main")
env.zenith_cli(["branch", branch, "main"])
pgs.append(env.postgres.create_start(branch))
tenant_id = env.initial_tenant
tenant_id = uuid.UUID(env.initial_tenant)
def collect_metrics(message: str) -> List[BranchMetrics]:
with env.pageserver.http_client() as pageserver_http:
@@ -90,8 +92,8 @@ def test_many_timelines(zenith_env_builder: ZenithEnvBuilder):
latest_valid_lsn=branch_detail["latest_valid_lsn"],
)
for sk_m in sk_metrics:
m.flush_lsns.append(sk_m.flush_lsn_inexact[(tenant_id.hex, timeline_id)])
m.commit_lsns.append(sk_m.commit_lsn_inexact[(tenant_id.hex, timeline_id)])
m.flush_lsns.append(sk_m.flush_lsn_inexact[timeline_id])
m.commit_lsns.append(sk_m.commit_lsn_inexact[timeline_id])
for flush_lsn, commit_lsn in zip(m.flush_lsns, m.commit_lsns):
# Invariant. May be < when transaction is in progress.
@@ -181,9 +183,9 @@ def test_restarts(zenith_env_builder: ZenithEnvBuilder):
n_acceptors = 3
zenith_env_builder.num_safekeepers = n_acceptors
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_wal_acceptors_restarts", "main")
env.zenith_cli(["branch", "test_wal_acceptors_restarts", "main"])
pg = env.postgres.create_start('test_wal_acceptors_restarts')
# we rely upon autocommit after each statement
@@ -218,9 +220,9 @@ def delayed_wal_acceptor_start(wa):
# When majority of acceptors is offline, commits are expected to be frozen
def test_unavailability(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 2
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_wal_acceptors_unavailability", "main")
env.zenith_cli(["branch", "test_wal_acceptors_unavailability", "main"])
pg = env.postgres.create_start('test_wal_acceptors_unavailability')
# we rely upon autocommit after each statement
@@ -289,9 +291,9 @@ def stop_value():
def test_race_conditions(zenith_env_builder: ZenithEnvBuilder, stop_value):
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_wal_acceptors_race_conditions", "main")
env.zenith_cli(["branch", "test_wal_acceptors_race_conditions", "main"])
pg = env.postgres.create_start('test_wal_acceptors_race_conditions')
# we rely upon autocommit after each statement
@@ -319,16 +321,16 @@ class ProposerPostgres(PgProtocol):
def __init__(self,
pgdata_dir: str,
pg_bin,
timeline_id: uuid.UUID,
tenant_id: uuid.UUID,
timeline_id: str,
tenant_id: str,
listen_addr: str,
port: int):
super().__init__(host=listen_addr, port=port, username='zenith_admin')
self.pgdata_dir: str = pgdata_dir
self.pg_bin: PgBin = pg_bin
self.timeline_id: uuid.UUID = timeline_id
self.tenant_id: uuid.UUID = tenant_id
self.timeline_id: str = timeline_id
self.tenant_id: str = tenant_id
self.listen_addr: str = listen_addr
self.port: int = port
@@ -348,8 +350,8 @@ class ProposerPostgres(PgProtocol):
cfg = [
"synchronous_standby_names = 'walproposer'\n",
"shared_preload_libraries = 'zenith'\n",
f"zenith.zenith_timeline = '{self.timeline_id.hex}'\n",
f"zenith.zenith_tenant = '{self.tenant_id.hex}'\n",
f"zenith.zenith_timeline = '{self.timeline_id}'\n",
f"zenith.zenith_tenant = '{self.tenant_id}'\n",
f"zenith.page_server_connstring = ''\n",
f"wal_acceptors = '{wal_acceptors}'\n",
f"listen_addresses = '{self.listen_addr}'\n",
@@ -404,10 +406,10 @@ def test_sync_safekeepers(zenith_env_builder: ZenithEnvBuilder,
# We don't really need the full environment for this test, just the
# safekeepers would be enough.
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
timeline_id = uuid.uuid4()
tenant_id = uuid.uuid4()
timeline_id = uuid.uuid4().hex
tenant_id = uuid.uuid4().hex
# write config for proposer
pgdata_dir = os.path.join(env.repo_dir, "proposer_pgdata")
@@ -454,9 +456,9 @@ def test_sync_safekeepers(zenith_env_builder: ZenithEnvBuilder,
def test_timeline_status(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_timeline_status", "main")
env.zenith_cli(["branch", "test_timeline_status", "main"])
pg = env.postgres.create_start('test_timeline_status')
wa = env.safekeepers[0]
@@ -493,15 +495,15 @@ class SafekeeperEnv:
self.bin_safekeeper = os.path.join(str(zenith_binpath), 'safekeeper')
self.safekeepers: Optional[List[subprocess.CompletedProcess[Any]]] = None
self.postgres: Optional[ProposerPostgres] = None
self.tenant_id: Optional[uuid.UUID] = None
self.timeline_id: Optional[uuid.UUID] = None
self.tenant_id: Optional[str] = None
self.timeline_id: Optional[str] = None
def init(self) -> "SafekeeperEnv":
assert self.postgres is None, "postgres is already initialized"
assert self.safekeepers is None, "safekeepers are already initialized"
self.timeline_id = uuid.uuid4()
self.tenant_id = uuid.uuid4()
self.timeline_id = uuid.uuid4().hex
self.tenant_id = uuid.uuid4().hex
mkdir_if_needed(str(self.repo_dir))
# Create config and a Safekeeper object for each safekeeper
@@ -521,7 +523,12 @@ class SafekeeperEnv:
http=self.port_distributor.get_port(),
)
safekeeper_dir = os.path.join(self.repo_dir, f"sk{i}")
if self.num_safekeepers == 1:
name = "single"
else:
name = f"sk{i}"
safekeeper_dir = os.path.join(self.repo_dir, name)
mkdir_if_needed(safekeeper_dir)
args = [
@@ -532,8 +539,6 @@ class SafekeeperEnv:
f"127.0.0.1:{port.http}",
"-D",
safekeeper_dir,
"--id",
str(i),
"--daemonize"
]
@@ -601,8 +606,9 @@ def test_safekeeper_without_pageserver(test_output_dir: str,
def test_replace_safekeeper(zenith_env_builder: ZenithEnvBuilder):
def safekeepers_guc(env: ZenithEnv, sk_names: List[int]) -> str:
return ','.join([f'localhost:{sk.port.pg}' for sk in env.safekeepers if sk.id in sk_names])
def safekeepers_guc(env: ZenithEnv, sk_names: List[str]) -> str:
return ','.join(
[f'localhost:{sk.port.pg}' for sk in env.safekeepers if sk.name in sk_names])
def execute_payload(pg: Postgres):
with closing(pg.connect()) as conn:
@@ -624,17 +630,17 @@ def test_replace_safekeeper(zenith_env_builder: ZenithEnvBuilder):
http_cli = sk.http_client()
try:
status = http_cli.timeline_status(tenant_id, timeline_id)
log.info(f"Safekeeper {sk.id} status: {status}")
log.info(f"Safekeeper {sk.name} status: {status}")
except Exception as e:
log.info(f"Safekeeper {sk.id} status error: {e}")
log.info(f"Safekeeper {sk.name} status error: {e}")
zenith_env_builder.num_safekeepers = 4
env = zenith_env_builder.init_start()
env.zenith_cli.create_branch("test_replace_safekeeper", "main")
env = zenith_env_builder.init()
env.zenith_cli(["branch", "test_replace_safekeeper", "main"])
log.info("Use only first 3 safekeepers")
env.safekeepers[3].stop()
active_safekeepers = [1, 2, 3]
active_safekeepers = ['sk1', 'sk2', 'sk3']
pg = env.postgres.create('test_replace_safekeeper')
pg.adjust_for_wal_acceptors(safekeepers_guc(env, active_safekeepers))
pg.start()
@@ -674,7 +680,7 @@ def test_replace_safekeeper(zenith_env_builder: ZenithEnvBuilder):
log.info("Recreate postgres to replace failed sk1 with new sk4")
pg.stop_and_destroy().create('test_replace_safekeeper')
active_safekeepers = [2, 3, 4]
active_safekeepers = ['sk2', 'sk3', 'sk4']
env.safekeepers[3].start()
pg.adjust_for_wal_acceptors(safekeepers_guc(env, active_safekeepers))
pg.start()

View File

@@ -9,6 +9,7 @@ from fixtures.utils import lsn_from_hex, lsn_to_hex
from typing import List
log = getLogger('root.wal_acceptor_async')
pytest_plugins = ("fixtures.zenith_fixtures")
class BankClient(object):
@@ -200,9 +201,9 @@ async def run_restarts_under_load(pg: Postgres, acceptors: List[Safekeeper], n_w
# restart acceptors one by one, while executing and validating bank transactions
def test_restarts_under_load(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_wal_acceptors_restarts_under_load", "main")
env.zenith_cli(["branch", "test_wal_acceptors_restarts_under_load", "main"])
pg = env.postgres.create_start('test_wal_acceptors_restarts_under_load')
asyncio.run(run_restarts_under_load(pg, env.safekeepers))

View File

@@ -3,26 +3,30 @@ import uuid
import requests
from psycopg2.extensions import cursor as PgCursor
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder, ZenithPageserverHttpClient
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder
from typing import cast
pytest_plugins = ("fixtures.zenith_fixtures")
def helper_compare_branch_list(pageserver_http_client: ZenithPageserverHttpClient,
env: ZenithEnv,
initial_tenant: uuid.UUID):
def helper_compare_branch_list(page_server_cur: PgCursor, env: ZenithEnv, initial_tenant: str):
"""
Compare branches list returned by CLI and directly via API.
Filters out branches created by other tests.
"""
branches = pageserver_http_client.branch_list(initial_tenant)
branches_api = sorted(map(lambda b: cast(str, b['name']), branches))
page_server_cur.execute(f'branch_list {initial_tenant}')
branches_api = sorted(
map(lambda b: cast(str, b['name']), json.loads(page_server_cur.fetchone()[0])))
branches_api = [b for b in branches_api if b.startswith('test_cli_') or b in ('empty', 'main')]
res = env.zenith_cli.list_branches()
res = env.zenith_cli(["branch"])
res.check_returncode()
branches_cli = sorted(map(lambda b: b.split(':')[-1].strip(), res.stdout.strip().split("\n")))
branches_cli = [b for b in branches_cli if b.startswith('test_cli_') or b in ('empty', 'main')]
res = env.zenith_cli.list_branches(tenant_id=initial_tenant)
res = env.zenith_cli(["branch", f"--tenantid={initial_tenant}"])
res.check_returncode()
branches_cli_with_tenant_arg = sorted(
map(lambda b: b.split(':')[-1].strip(), res.stdout.strip().split("\n")))
branches_cli_with_tenant_arg = [
@@ -34,20 +38,24 @@ def helper_compare_branch_list(pageserver_http_client: ZenithPageserverHttpClien
def test_cli_branch_list(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
pageserver_http_client = env.pageserver.http_client()
page_server_conn = env.pageserver.connect()
page_server_cur = page_server_conn.cursor()
# Initial sanity check
helper_compare_branch_list(pageserver_http_client, env, env.initial_tenant)
env.zenith_cli.create_branch("test_cli_branch_list_main", "empty")
helper_compare_branch_list(pageserver_http_client, env, env.initial_tenant)
helper_compare_branch_list(page_server_cur, env, env.initial_tenant)
# Create a branch for us
res = env.zenith_cli(["branch", "test_cli_branch_list_main", "empty"])
assert res.stderr == ''
helper_compare_branch_list(page_server_cur, env, env.initial_tenant)
# Create a nested branch
res = env.zenith_cli.create_branch("test_cli_branch_list_nested", "test_cli_branch_list_main")
res = env.zenith_cli(["branch", "test_cli_branch_list_nested", "test_cli_branch_list_main"])
assert res.stderr == ''
helper_compare_branch_list(pageserver_http_client, env, env.initial_tenant)
helper_compare_branch_list(page_server_cur, env, env.initial_tenant)
# Check that all new branches are visible via CLI
res = env.zenith_cli.list_branches()
res = env.zenith_cli(["branch"])
assert res.stderr == ''
branches_cli = sorted(map(lambda b: b.split(':')[-1].strip(), res.stdout.strip().split("\n")))
@@ -55,11 +63,12 @@ def test_cli_branch_list(zenith_simple_env: ZenithEnv):
assert 'test_cli_branch_list_nested' in branches_cli
def helper_compare_tenant_list(pageserver_http_client: ZenithPageserverHttpClient, env: ZenithEnv):
tenants = pageserver_http_client.tenant_list()
tenants_api = sorted(map(lambda t: cast(str, t['id']), tenants))
def helper_compare_tenant_list(page_server_cur: PgCursor, env: ZenithEnv):
page_server_cur.execute(f'tenant_list')
tenants_api = sorted(
map(lambda t: cast(str, t['id']), json.loads(page_server_cur.fetchone()[0])))
res = env.zenith_cli.list_tenants()
res = env.zenith_cli(["tenant", "list"])
assert res.stderr == ''
tenants_cli = sorted(map(lambda t: t.split()[0], res.stdout.splitlines()))
@@ -68,36 +77,41 @@ def helper_compare_tenant_list(pageserver_http_client: ZenithPageserverHttpClien
def test_cli_tenant_list(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
pageserver_http_client = env.pageserver.http_client()
page_server_conn = env.pageserver.connect()
page_server_cur = page_server_conn.cursor()
# Initial sanity check
helper_compare_tenant_list(pageserver_http_client, env)
helper_compare_tenant_list(page_server_cur, env)
# Create new tenant
tenant1 = uuid.uuid4()
env.zenith_cli.create_tenant(tenant1)
tenant1 = uuid.uuid4().hex
res = env.zenith_cli(["tenant", "create", tenant1])
res.check_returncode()
# check tenant1 appeared
helper_compare_tenant_list(pageserver_http_client, env)
helper_compare_tenant_list(page_server_cur, env)
# Create new tenant
tenant2 = uuid.uuid4()
env.zenith_cli.create_tenant(tenant2)
tenant2 = uuid.uuid4().hex
res = env.zenith_cli(["tenant", "create", tenant2])
res.check_returncode()
# check tenant2 appeared
helper_compare_tenant_list(pageserver_http_client, env)
helper_compare_tenant_list(page_server_cur, env)
res = env.zenith_cli.list_tenants()
res = env.zenith_cli(["tenant", "list"])
res.check_returncode()
tenants = sorted(map(lambda t: t.split()[0], res.stdout.splitlines()))
assert env.initial_tenant.hex in tenants
assert tenant1.hex in tenants
assert tenant2.hex in tenants
assert env.initial_tenant in tenants
assert tenant1 in tenants
assert tenant2 in tenants
def test_cli_ipv4_listeners(zenith_env_builder: ZenithEnvBuilder):
# Start with single sk
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
# Connect to sk port on v4 loopback
res = requests.get(f'http://127.0.0.1:{env.safekeepers[0].port.http}/v1/status')
@@ -109,21 +123,3 @@ def test_cli_ipv4_listeners(zenith_env_builder: ZenithEnvBuilder):
# Connect to ps port on v4 loopback
# res = requests.get(f'http://127.0.0.1:{env.pageserver.service_port.http}/v1/status')
# assert res.ok
def test_cli_start_stop(zenith_env_builder: ZenithEnvBuilder):
# Start with single sk
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init_start()
# Stop default ps/sk
env.zenith_cli.pageserver_stop()
env.zenith_cli.safekeeper_stop()
# Default start
res = env.zenith_cli.raw_cli(["start"])
res.check_returncode()
# Default stop
res = env.zenith_cli.raw_cli(["stop"])
res.check_returncode()

View File

@@ -3,11 +3,15 @@ import os
from fixtures.utils import mkdir_if_needed
from fixtures.zenith_fixtures import ZenithEnv, base_dir, pg_distrib_dir
pytest_plugins = ("fixtures.zenith_fixtures")
def test_isolation(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys):
env = zenith_simple_env
env.zenith_cli.create_branch("test_isolation", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_isolation", "empty"])
# Connect to postgres and create a database called "regression".
# isolation tests use prepared transactions, so enable them
pg = env.postgres.create_start('test_isolation', config_lines=['max_prepared_transactions=100'])

View File

@@ -3,11 +3,15 @@ import os
from fixtures.utils import mkdir_if_needed
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content, base_dir, pg_distrib_dir
pytest_plugins = ("fixtures.zenith_fixtures")
def test_pg_regress(zenith_simple_env: ZenithEnv, test_output_dir: str, pg_bin, capsys):
env = zenith_simple_env
env.zenith_cli.create_branch("test_pg_regress", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_pg_regress", "empty"])
# Connect to postgres and create a database called "regression".
pg = env.postgres.create_start('test_pg_regress')
pg.safe_psql('CREATE DATABASE regression')

View File

@@ -7,11 +7,15 @@ from fixtures.zenith_fixtures import (ZenithEnv,
pg_distrib_dir)
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
def test_zenith_regress(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys):
env = zenith_simple_env
env.zenith_cli.create_branch("test_zenith_regress", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_zenith_regress", "empty"])
# Connect to postgres and create a database called "regression".
pg = env.postgres.create_start('test_zenith_regress')
pg.safe_psql('CREATE DATABASE regression')

View File

@@ -1,6 +1 @@
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
"fixtures.slow",
)
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")

View File

@@ -8,7 +8,6 @@ import timeit
import calendar
import enum
from datetime import datetime
import uuid
import pytest
from _pytest.config import Config
from _pytest.terminal import TerminalReporter
@@ -27,6 +26,8 @@ bencmark, and then record the result by calling zenbenchmark.record. For example
import timeit
from fixtures.zenith_fixtures import ZenithEnv
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
def test_mybench(zenith_simple_env: env, zenbenchmark):
# Initialize the test
@@ -39,8 +40,6 @@ def test_mybench(zenith_simple_env: env, zenbenchmark):
# Record another measurement
zenbenchmark.record('speed_of_light', 300000, 'km/s')
There's no need to import this file to use it. It should be declared as a plugin
inside conftest.py, and that makes it available to all tests.
You can measure multiple things in one test, and record each one with a separate
call to zenbenchmark. For example, you could time the bulk loading that happens
@@ -277,11 +276,11 @@ class ZenithBenchmarker:
assert matches
return int(round(float(matches.group(1))))
def get_timeline_size(self, repo_dir: Path, tenantid: uuid.UUID, timelineid: str):
def get_timeline_size(self, repo_dir: Path, tenantid: str, timelineid: str):
"""
Calculate the on-disk size of a timeline
"""
path = "{}/tenants/{}/timelines/{}".format(repo_dir, tenantid.hex, timelineid)
path = "{}/tenants/{}/timelines/{}".format(repo_dir, tenantid, timelineid)
totalbytes = 0
for root, dirs, files in os.walk(path):

View File

@@ -25,10 +25,6 @@ class PgCompare(ABC):
def pg_bin(self) -> PgBin:
pass
@property
def zenbenchmark(self) -> ZenithBenchmarker:
pass
@abstractmethod
def flush(self) -> None:
pass
@@ -52,42 +48,6 @@ class PgCompare(ABC):
pass
class RemoteCompare(PgCompare):
"""PgCompare interface for testing against a remote zenith deployment."""
def __init__(self, zenbenchmark: ZenithBenchmarker, pg: PgProtocol, pg_bin: PgBin):
self._zenbenchmark = zenbenchmark
self._pg = pg
self._pg_bin = pg_bin
@property
def pg(self):
return self._pg
@property
def zenbenchmark(self):
return self._zenbenchmark
@property
def pg_bin(self):
return self._pg_bin
def flush(self):
# We can't flush because we don't have a direct connection to the pageserver
pass
def report_peak_memory_use(self) -> None:
pass
def report_size(self) -> None:
pass
def record_pageserver_writes(self, out_name):
pass
def record_duration(self, out_name):
return self.zenbenchmark.record_duration(out_name)
class ZenithCompare(PgCompare):
"""PgCompare interface for the zenith stack."""
def __init__(self,
@@ -96,12 +56,12 @@ class ZenithCompare(PgCompare):
pg_bin: PgBin,
branch_name):
self.env = zenith_simple_env
self._zenbenchmark = zenbenchmark
self.zenbenchmark = zenbenchmark
self._pg_bin = pg_bin
# We only use one branch and one timeline
self.branch = branch_name
self.env.zenith_cli.create_branch(self.branch, "empty")
self.env.zenith_cli(["branch", self.branch, "empty"])
self._pg = self.env.postgres.create_start(self.branch)
self.timeline = self.pg.safe_psql("SHOW zenith.zenith_timeline")[0][0]
@@ -113,16 +73,12 @@ class ZenithCompare(PgCompare):
def pg(self):
return self._pg
@property
def zenbenchmark(self):
return self._zenbenchmark
@property
def pg_bin(self):
return self._pg_bin
def flush(self):
self.pscur.execute(f"do_gc {self.env.initial_tenant.hex} {self.timeline} 0")
self.pscur.execute(f"do_gc {self.env.initial_tenant} {self.timeline} 0")
def report_peak_memory_use(self) -> None:
self.zenbenchmark.record("peak_mem",
@@ -150,7 +106,7 @@ class VanillaCompare(PgCompare):
"""PgCompare interface for vanilla postgres."""
def __init__(self, zenbenchmark, vanilla_pg: VanillaPostgres):
self._pg = vanilla_pg
self._zenbenchmark = zenbenchmark
self.zenbenchmark = zenbenchmark
vanilla_pg.configure(['shared_buffers=1MB'])
vanilla_pg.start()
@@ -162,10 +118,6 @@ class VanillaCompare(PgCompare):
def pg(self):
return self._pg
@property
def zenbenchmark(self):
return self._zenbenchmark
@property
def pg_bin(self):
return self._pg.pg_bin
@@ -202,54 +154,12 @@ def zenith_compare(request, zenbenchmark, pg_bin, zenith_simple_env) -> ZenithCo
return ZenithCompare(zenbenchmark, zenith_simple_env, pg_bin, branch_name)
@pytest.fixture(scope='function')
def remote_compare(request, zenbenchmark, pg_bin, zenith_simple_env) -> Iterator[RemoteCompare]:
schema_name = "pytest_" + request.node.name.replace("[", "__").replace("]", "")
schema_identifier = f"\"{schema_name}\""
# NOTE password not specified, comes from pgpass file
stage_pg = PgProtocol(
host="start.stage.zenith.tech",
port=5432,
username="bojanserafimov@zenith",
dbname="main",
schema=schema_identifier,
)
# HACK for now testing with local zenith because proxy doesn't support schema option
# https://github.com/zenithdb/zenith/pull/1344
zenith_simple_env.zenith_cli.create_branch(schema_name, "empty")
zenith_pg = zenith_simple_env.postgres.create_start(schema_name)
pg = PgProtocol(
host=zenith_pg.host,
port=zenith_pg.port,
username=zenith_pg.username,
password=zenith_pg.password,
dbname=zenith_pg.dbname,
schema=schema_identifier,
)
pg.safe_psql(f"DROP SCHEMA IF EXISTS {schema_identifier} CASCADE;")
pg.safe_psql(f"CREATE SCHEMA {schema_identifier};")
yield RemoteCompare(zenbenchmark, pg, pg_bin)
pg.safe_psql(f"DROP SCHEMA IF EXISTS {schema_identifier} CASCADE;")
@pytest.fixture(scope='function')
def vanilla_compare(zenbenchmark, vanilla_pg) -> VanillaCompare:
return VanillaCompare(zenbenchmark, vanilla_pg)
@pytest.fixture(params=[
"vanilla_compare",
"zenith_compare",
"remote_compare",
],
ids=[
"vanilla",
"zenith",
"remote",
])
@pytest.fixture(params=["vanilla_compare", "zenith_compare"], ids=["vanilla", "zenith"])
def zenith_with_baseline(request) -> PgCompare:
"""Parameterized fixture that helps compare zenith against vanilla postgres.

View File

@@ -1,26 +0,0 @@
import pytest
"""
This plugin allows tests to be marked as slow using pytest.mark.slow. By default slow
tests are excluded. They need to be specifically requested with the --runslow flag in
order to run.
Copied from here: https://docs.pytest.org/en/latest/example/simple.html
"""
def pytest_addoption(parser):
parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
def pytest_configure(config):
config.addinivalue_line("markers", "slow: mark test as slow to run")
def pytest_collection_modifyitems(config, items):
if config.getoption("--runslow"):
# --runslow given in cli: do not skip slow tests
return
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass, field
import textwrap
from cached_property import cached_property
import asyncpg
import os
@@ -27,12 +26,11 @@ from dataclasses import dataclass
# Type-related stuff
from psycopg2.extensions import connection as PgConnection
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, TypeVar, cast, Union, Tuple
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast, Union
from typing_extensions import Literal
import pytest
import requests
import backoff # type: ignore
from .utils import (get_self_dir, mkdir_if_needed, subprocess_capture)
from fixtures.log_helper import log
@@ -46,8 +44,9 @@ the standard pytest.fixture with some extra behavior.
There are several environment variables that can control the running of tests:
ZENITH_BIN, POSTGRES_DISTRIB_DIR, etc. See README.md for more information.
There's no need to import this file to use it. It should be declared as a plugin
inside conftest.py, and that makes it available to all tests.
To use fixtures in a test file, add this line of code:
>>> pytest_plugins = ("fixtures.zenith_fixtures")
Don't import functions from this file, or pytest will emit warnings. Instead
put directly-importable functions into utils.py or another separate file.
@@ -238,24 +237,14 @@ def port_distributor(worker_base_port):
class PgProtocol:
""" Reusable connection logic """
def __init__(self,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
dbname: Optional[str] = None,
schema: Optional[str] = None):
def __init__(self, host: str, port: int, username: Optional[str] = None):
self.host = host
self.port = port
self.username = username
self.password = password
self.dbname = dbname
self.schema = schema
def connstr(self,
*,
dbname: Optional[str] = None,
schema: Optional[str] = None,
dbname: str = 'postgres',
username: Optional[str] = None,
password: Optional[str] = None) -> str:
"""
@@ -263,9 +252,6 @@ class PgProtocol:
"""
username = username or self.username
password = password or self.password
dbname = dbname or self.dbname or "postgres"
schema = schema or self.schema
res = f'host={self.host} port={self.port} dbname={dbname}'
if username:
@@ -274,17 +260,13 @@ class PgProtocol:
if password:
res = f'{res} password={password}'
if schema:
res = f"{res} options='-c search_path={self.schema}'"
return res
# autocommit=True here by default because that's what we need most of the time
def connect(self,
*,
autocommit=True,
dbname: Optional[str] = None,
schema: Optional[str] = None,
dbname: str = 'postgres',
username: Optional[str] = None,
password: Optional[str] = None) -> PgConnection:
"""
@@ -419,7 +401,6 @@ class ZenithEnvBuilder:
repo_dir: Path,
port_distributor: PortDistributor,
pageserver_remote_storage: Optional[RemoteStorage] = None,
pageserver_config_override: Optional[str] = None,
num_safekeepers: int = 0,
pageserver_auth_enabled: bool = False,
rust_log_override: Optional[str] = None):
@@ -427,7 +408,6 @@ class ZenithEnvBuilder:
self.rust_log_override = rust_log_override
self.port_distributor = port_distributor
self.pageserver_remote_storage = pageserver_remote_storage
self.pageserver_config_override = pageserver_config_override
self.num_safekeepers = num_safekeepers
self.pageserver_auth_enabled = pageserver_auth_enabled
self.env: Optional[ZenithEnv] = None
@@ -445,14 +425,6 @@ class ZenithEnvBuilder:
self.env = ZenithEnv(self)
return self.env
def start(self):
self.env.start()
def init_start(self) -> ZenithEnv:
env = self.init()
self.start()
return env
"""
Sets up the pageserver to use the local fs at the `test_dir/local_fs_remote_storage` path.
Errors, if the pageserver has some remote storage configuration already, unless `force_enable` is not set to `True`.
@@ -544,7 +516,6 @@ class ZenithEnv:
self.rust_log_override = config.rust_log_override
self.port_distributor = config.port_distributor
self.s3_mock_server = config.s3_mock_server
self.zenith_cli = ZenithCli(env=self)
self.postgres = PostgresFactory(self)
@@ -552,12 +523,12 @@ class ZenithEnv:
# generate initial tenant ID here instead of letting 'zenith init' generate it,
# so that we don't need to dig it out of the config file afterwards.
self.initial_tenant = uuid.uuid4()
self.initial_tenant = uuid.uuid4().hex
# Create a config file corresponding to the options
toml = textwrap.dedent(f"""
default_tenantid = '{self.initial_tenant.hex}'
""")
toml = f"""
default_tenantid = '{self.initial_tenant}'
"""
# Create config for pageserver
pageserver_port = PageserverPort(
@@ -566,19 +537,17 @@ class ZenithEnv:
)
pageserver_auth_type = "ZenithJWT" if config.pageserver_auth_enabled else "Trust"
toml += textwrap.dedent(f"""
[pageserver]
id=1
listen_pg_addr = 'localhost:{pageserver_port.pg}'
listen_http_addr = 'localhost:{pageserver_port.http}'
auth_type = '{pageserver_auth_type}'
""")
toml += f"""
[pageserver]
listen_pg_addr = 'localhost:{pageserver_port.pg}'
listen_http_addr = 'localhost:{pageserver_port.http}'
auth_type = '{pageserver_auth_type}'
"""
# Create a corresponding ZenithPageserver object
self.pageserver = ZenithPageserver(self,
port=pageserver_port,
remote_storage=config.pageserver_remote_storage,
config_override=config.pageserver_config_override)
remote_storage=config.pageserver_remote_storage)
# Create config and a Safekeeper object for each safekeeper
for i in range(1, config.num_safekeepers + 1):
@@ -586,22 +555,33 @@ class ZenithEnv:
pg=self.port_distributor.get_port(),
http=self.port_distributor.get_port(),
)
id = i # assign ids sequentially
toml += textwrap.dedent(f"""
[[safekeepers]]
id = {id}
pg_port = {port.pg}
http_port = {port.http}
sync = false # Disable fsyncs to make the tests go faster
""")
safekeeper = Safekeeper(env=self, id=id, port=port)
if config.num_safekeepers == 1:
name = "single"
else:
name = f"sk{i}"
toml += f"""
[[safekeepers]]
name = '{name}'
pg_port = {port.pg}
http_port = {port.http}
sync = false # Disable fsyncs to make the tests go faster
"""
safekeeper = Safekeeper(env=self, name=name, port=port)
self.safekeepers.append(safekeeper)
log.info(f"Config: {toml}")
self.zenith_cli.init(toml)
# Run 'zenith init' using the config file we constructed
with tempfile.NamedTemporaryFile(mode='w+') as tmp:
tmp.write(toml)
tmp.flush()
cmd = ['init', f'--config={tmp.name}']
append_pageserver_param_overrides(cmd, config.pageserver_remote_storage)
self.zenith_cli(cmd)
def start(self):
# Start up the page server and all the safekeepers
self.pageserver.start()
@@ -612,12 +592,69 @@ class ZenithEnv:
""" Get list of safekeeper endpoints suitable for wal_acceptors GUC """
return ','.join([f'localhost:{wa.port.pg}' for wa in self.safekeepers])
def create_tenant(self, tenant_id: Optional[uuid.UUID] = None) -> uuid.UUID:
def create_tenant(self, tenant_id: Optional[str] = None):
if tenant_id is None:
tenant_id = uuid.uuid4()
self.zenith_cli.create_tenant(tenant_id)
tenant_id = uuid.uuid4().hex
res = self.zenith_cli(['tenant', 'create', tenant_id])
res.check_returncode()
return tenant_id
def zenith_cli(self, arguments: List[str]) -> 'subprocess.CompletedProcess[str]':
"""
Run "zenith" with the specified arguments.
Arguments must be in list form, e.g. ['pg', 'create']
Return both stdout and stderr, which can be accessed as
>>> result = env.zenith_cli(...)
>>> assert result.stderr == ""
>>> log.info(result.stdout)
"""
assert type(arguments) == list
bin_zenith = os.path.join(str(zenith_binpath), 'zenith')
args = [bin_zenith] + arguments
log.info('Running command "{}"'.format(' '.join(args)))
log.info(f'Running in "{self.repo_dir}"')
env_vars = os.environ.copy()
env_vars['ZENITH_REPO_DIR'] = str(self.repo_dir)
env_vars['POSTGRES_DISTRIB_DIR'] = str(pg_distrib_dir)
if self.rust_log_override is not None:
env_vars['RUST_LOG'] = self.rust_log_override
# Pass coverage settings
var = 'LLVM_PROFILE_FILE'
val = os.environ.get(var)
if val:
env_vars[var] = val
# Intercept CalledProcessError and print more info
try:
res = subprocess.run(args,
env=env_vars,
check=True,
universal_newlines=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
log.info(f"Run success: {res.stdout}")
except subprocess.CalledProcessError as exc:
# this way command output will be in recorded and shown in CI in failure message
msg = f"""\
Run failed: {exc}
stdout: {exc.stdout}
stderr: {exc.stderr}
"""
log.info(msg)
raise Exception(msg) from exc
return res
@cached_property
def auth_keys(self) -> AuthKeys:
pub = (Path(self.repo_dir) / 'auth_public_key.pem').read_bytes()
@@ -642,10 +679,10 @@ def _shared_simple_env(request: Any, port_distributor) -> Iterator[ZenithEnv]:
with ZenithEnvBuilder(Path(repo_dir), port_distributor) as builder:
env = builder.init_start()
env = builder.init()
# For convenience in tests, create a branch from the freshly-initialized cluster.
env.zenith_cli.create_branch("empty", "main")
env.zenith_cli(["branch", "empty", "main"])
# Return the builder to the caller
yield env
@@ -676,7 +713,7 @@ def zenith_env_builder(test_output_dir, port_distributor) -> Iterator[ZenithEnvB
To use, define 'zenith_env_builder' fixture in your test to get access to the
builder object. Set properties on it to describe the environment.
Finally, initialize and start up the environment by calling
zenith_env_builder.init_start().
zenith_env_builder.init().
After the initialization, you can launch compute nodes by calling
the functions in the 'env.postgres' factory object, stop/start the
@@ -691,10 +728,6 @@ def zenith_env_builder(test_output_dir, port_distributor) -> Iterator[ZenithEnvB
yield builder
class ZenithPageserverApiException(Exception):
pass
class ZenithPageserverHttpClient(requests.Session):
def __init__(self, port: int, auth_token: Optional[str] = None) -> None:
super().__init__()
@@ -704,32 +737,22 @@ class ZenithPageserverHttpClient(requests.Session):
if auth_token is not None:
self.headers['Authorization'] = f'Bearer {auth_token}'
def verbose_error(self, res: requests.Response):
try:
res.raise_for_status()
except requests.RequestException as e:
try:
msg = res.json()['msg']
except:
msg = ''
raise ZenithPageserverApiException(msg) from e
def check_status(self):
self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()
def timeline_attach(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID):
res = self.post(
f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}/{timeline_id.hex}/attach", )
self.verbose_error(res)
res.raise_for_status()
def timeline_detach(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID):
res = self.post(
f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}/{timeline_id.hex}/detach", )
self.verbose_error(res)
res.raise_for_status()
def branch_list(self, tenant_id: uuid.UUID) -> List[Dict[Any, Any]]:
res = self.get(f"http://localhost:{self.port}/v1/branch/{tenant_id.hex}")
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, list)
return res_json
@@ -741,7 +764,7 @@ class ZenithPageserverHttpClient(requests.Session):
'name': name,
'start_point': start_point,
})
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, dict)
return res_json
@@ -750,14 +773,14 @@ class ZenithPageserverHttpClient(requests.Session):
res = self.get(
f"http://localhost:{self.port}/v1/branch/{tenant_id.hex}/{name}?include-non-incremental-logical-size=1",
)
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, dict)
return res_json
def tenant_list(self) -> List[Dict[Any, Any]]:
res = self.get(f"http://localhost:{self.port}/v1/tenant")
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, list)
return res_json
@@ -769,27 +792,27 @@ class ZenithPageserverHttpClient(requests.Session):
'tenant_id': tenant_id.hex,
},
)
self.verbose_error(res)
res.raise_for_status()
return res.json()
def timeline_list(self, tenant_id: uuid.UUID) -> List[str]:
res = self.get(f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}")
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, list)
return res_json
def timeline_detail(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID) -> Dict[Any, Any]:
def timeline_detail(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID):
res = self.get(
f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}/{timeline_id.hex}")
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, dict)
return res_json
def get_metrics(self) -> str:
res = self.get(f"http://localhost:{self.port}/metrics")
self.verbose_error(res)
res.raise_for_status()
return res.text
@@ -816,193 +839,6 @@ class S3Storage:
RemoteStorage = Union[LocalFsStorage, S3Storage]
class ZenithCli:
"""
A typed wrapper around the `zenith` CLI tool.
Supports main commands via typed methods and a way to run arbitrary command directly via CLI.
"""
def __init__(self, env: ZenithEnv) -> None:
self.env = env
pass
def create_tenant(self, tenant_id: Optional[uuid.UUID] = None) -> uuid.UUID:
if tenant_id is None:
tenant_id = uuid.uuid4()
self.raw_cli(['tenant', 'create', tenant_id.hex])
return tenant_id
def list_tenants(self) -> 'subprocess.CompletedProcess[str]':
return self.raw_cli(['tenant', 'list'])
def create_branch(self,
branch_name: str,
starting_point: str,
tenant_id: Optional[uuid.UUID] = None) -> 'subprocess.CompletedProcess[str]':
args = ['branch']
if tenant_id is not None:
args.extend(['--tenantid', tenant_id.hex])
args.extend([branch_name, starting_point])
return self.raw_cli(args)
def list_branches(self,
tenant_id: Optional[uuid.UUID] = None) -> 'subprocess.CompletedProcess[str]':
args = ['branch']
if tenant_id is not None:
args.extend(['--tenantid', tenant_id.hex])
return self.raw_cli(args)
def init(self, config_toml: str) -> 'subprocess.CompletedProcess[str]':
with tempfile.NamedTemporaryFile(mode='w+') as tmp:
tmp.write(config_toml)
tmp.flush()
cmd = ['init', f'--config={tmp.name}']
append_pageserver_param_overrides(cmd,
self.env.pageserver.remote_storage,
self.env.pageserver.config_override)
return self.raw_cli(cmd)
def pageserver_start(self, overrides=()) -> 'subprocess.CompletedProcess[str]':
start_args = ['pageserver', 'start', *overrides]
append_pageserver_param_overrides(start_args,
self.env.pageserver.remote_storage,
self.env.pageserver.config_override)
return self.raw_cli(start_args)
def pageserver_stop(self, immediate=False) -> 'subprocess.CompletedProcess[str]':
cmd = ['pageserver', 'stop']
if immediate:
cmd.extend(['-m', 'immediate'])
log.info(f"Stopping pageserver with {cmd}")
return self.raw_cli(cmd)
def safekeeper_start(self, id: int) -> 'subprocess.CompletedProcess[str]':
return self.raw_cli(['safekeeper', 'start', str(id)])
def safekeeper_stop(self,
id: Optional[int] = None,
immediate=False) -> 'subprocess.CompletedProcess[str]':
args = ['safekeeper', 'stop']
if id is not None:
args.extend(str(id))
if immediate:
args.extend(['-m', 'immediate'])
return self.raw_cli(args)
def pg_create(
self,
node_name: str,
tenant_id: Optional[uuid.UUID] = None,
timeline_spec: Optional[str] = None,
port: Optional[int] = None,
) -> 'subprocess.CompletedProcess[str]':
args = ['pg', 'create']
if tenant_id is not None:
args.extend(['--tenantid', tenant_id.hex])
if port is not None:
args.append(f'--port={port}')
args.append(node_name)
if timeline_spec is not None:
args.append(timeline_spec)
return self.raw_cli(args)
def pg_start(
self,
node_name: str,
tenant_id: Optional[uuid.UUID] = None,
timeline_spec: Optional[str] = None,
port: Optional[int] = None,
) -> 'subprocess.CompletedProcess[str]':
args = ['pg', 'start']
if tenant_id is not None:
args.extend(['--tenantid', tenant_id.hex])
if port is not None:
args.append(f'--port={port}')
args.append(node_name)
if timeline_spec is not None:
args.append(timeline_spec)
return self.raw_cli(args)
def pg_stop(
self,
node_name: str,
tenant_id: Optional[uuid.UUID] = None,
destroy=False,
) -> 'subprocess.CompletedProcess[str]':
args = ['pg', 'stop']
if tenant_id is not None:
args.extend(['--tenantid', tenant_id.hex])
if destroy:
args.append('--destroy')
args.append(node_name)
return self.raw_cli(args)
def raw_cli(self,
arguments: List[str],
check_return_code=True) -> 'subprocess.CompletedProcess[str]':
"""
Run "zenith" with the specified arguments.
Arguments must be in list form, e.g. ['pg', 'create']
Return both stdout and stderr, which can be accessed as
>>> result = env.zenith_cli.raw_cli(...)
>>> assert result.stderr == ""
>>> log.info(result.stdout)
"""
assert type(arguments) == list
bin_zenith = os.path.join(str(zenith_binpath), 'zenith')
args = [bin_zenith] + arguments
log.info('Running command "{}"'.format(' '.join(args)))
log.info(f'Running in "{self.env.repo_dir}"')
env_vars = os.environ.copy()
env_vars['ZENITH_REPO_DIR'] = str(self.env.repo_dir)
env_vars['POSTGRES_DISTRIB_DIR'] = str(pg_distrib_dir)
if self.env.rust_log_override is not None:
env_vars['RUST_LOG'] = self.env.rust_log_override
# Pass coverage settings
var = 'LLVM_PROFILE_FILE'
val = os.environ.get(var)
if val:
env_vars[var] = val
# Intercept CalledProcessError and print more info
try:
res = subprocess.run(args,
env=env_vars,
check=True,
universal_newlines=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
log.info(f"Run success: {res.stdout}")
except subprocess.CalledProcessError as exc:
# this way command output will be in recorded and shown in CI in failure message
msg = f"""\
Run failed: {exc}
stdout: {exc.stdout}
stderr: {exc.stderr}
"""
log.info(msg)
raise Exception(msg) from exc
if check_return_code:
res.check_returncode()
return res
class ZenithPageserver(PgProtocol):
"""
An object representing a running pageserver.
@@ -1013,24 +849,24 @@ class ZenithPageserver(PgProtocol):
env: ZenithEnv,
port: PageserverPort,
remote_storage: Optional[RemoteStorage] = None,
config_override: Optional[str] = None,
enable_auth=False):
super().__init__(host='localhost', port=port.pg, username='zenith_admin')
self.env = env
self.running = False
self.service_port = port # do not shadow PgProtocol.port which is just int
self.remote_storage = remote_storage
self.config_override = config_override
def start(self, overrides=()) -> 'ZenithPageserver':
def start(self) -> 'ZenithPageserver':
"""
Start the page server.
`overrides` allows to add some config to this pageserver start.
Returns self.
"""
assert self.running == False
self.env.zenith_cli.pageserver_start(overrides=overrides)
start_args = ['pageserver', 'start']
append_pageserver_param_overrides(start_args, self.remote_storage)
self.env.zenith_cli(start_args)
self.running = True
return self
@@ -1039,8 +875,13 @@ class ZenithPageserver(PgProtocol):
Stop the page server.
Returns self.
"""
cmd = ['pageserver', 'stop']
if immediate:
cmd.extend(['-m', 'immediate'])
log.info(f"Stopping pageserver with {cmd}")
if self.running:
self.env.zenith_cli.pageserver_stop(immediate)
self.env.zenith_cli(cmd)
self.running = False
return self
@@ -1058,11 +899,8 @@ class ZenithPageserver(PgProtocol):
)
def append_pageserver_param_overrides(
params_to_update: List[str],
pageserver_remote_storage: Optional[RemoteStorage],
pageserver_config_override: Optional[str] = None,
):
def append_pageserver_param_overrides(params_to_update: List[str],
pageserver_remote_storage: Optional[RemoteStorage]):
if pageserver_remote_storage is not None:
if isinstance(pageserver_remote_storage, LocalFsStorage):
pageserver_storage_override = f"local_path='{pageserver_remote_storage.root}'"
@@ -1088,12 +926,6 @@ def append_pageserver_param_overrides(
f'--pageserver-config-override={o.strip()}' for o in env_overrides.split(';')
]
if pageserver_config_override is not None:
params_to_update += [
f'--pageserver-config-override={o.strip()}'
for o in pageserver_config_override.split(';')
]
class PgBin:
""" A helper class for executing postgres binaries """
@@ -1200,53 +1032,9 @@ def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]:
yield vanilla_pg
class ZenithProxy(PgProtocol):
def __init__(self, port: int):
super().__init__(host="127.0.0.1", username="pytest", password="pytest", port=port)
self.http_port = 7001
self._popen: Optional[subprocess.Popen[bytes]] = None
def start_static(self, addr="127.0.0.1:5432") -> None:
assert self._popen is None
# Start proxy
bin_proxy = os.path.join(str(zenith_binpath), 'proxy')
args = [bin_proxy]
args.extend(["--http", f"{self.host}:{self.http_port}"])
args.extend(["--proxy", f"{self.host}:{self.port}"])
args.extend(["--auth-method", "password"])
args.extend(["--static-router", addr])
self._popen = subprocess.Popen(args)
self._wait_until_ready()
@backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10)
def _wait_until_ready(self):
requests.get(f"http://{self.host}:{self.http_port}/v1/status")
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
if self._popen is not None:
# NOTE the process will die when we're done with tests anyway, because
# it's a child process. This is mostly to clean up in between different tests.
self._popen.kill()
@pytest.fixture(scope='function')
def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]:
"""Zenith proxy that routes directly to vanilla postgres."""
vanilla_pg.start()
vanilla_pg.safe_psql("create user pytest with password 'pytest';")
with ZenithProxy(4432) as proxy:
proxy.start_static()
yield proxy
class Postgres(PgProtocol):
""" An object representing a running postgres daemon. """
def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int):
def __init__(self, env: ZenithEnv, tenant_id: str, port: int):
super().__init__(host='localhost', port=port, username='zenith_admin')
self.env = env
@@ -1273,12 +1061,16 @@ class Postgres(PgProtocol):
if branch is None:
branch = node_name
self.env.zenith_cli.pg_create(node_name,
tenant_id=self.tenant_id,
port=self.port,
timeline_spec=branch)
self.env.zenith_cli([
'pg',
'create',
f'--tenantid={self.tenant_id}',
f'--port={self.port}',
node_name,
branch
])
self.node_name = node_name
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id.hex / self.node_name
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id / self.node_name
self.pgdata_dir = os.path.join(self.env.repo_dir, path)
if config_lines is None:
@@ -1297,9 +1089,8 @@ class Postgres(PgProtocol):
log.info(f"Starting postgres node {self.node_name}")
run_result = self.env.zenith_cli.pg_start(self.node_name,
tenant_id=self.tenant_id,
port=self.port)
run_result = self.env.zenith_cli(
['pg', 'start', f'--tenantid={self.tenant_id}', f'--port={self.port}', self.node_name])
self.running = True
log.info(f"stdout: {run_result.stdout}")
@@ -1309,7 +1100,7 @@ class Postgres(PgProtocol):
def pg_data_dir_path(self) -> str:
""" Path to data directory """
assert self.node_name
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id.hex / self.node_name
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id / self.node_name
return os.path.join(self.env.repo_dir, path)
def pg_xact_dir_path(self) -> str:
@@ -1369,7 +1160,7 @@ class Postgres(PgProtocol):
if self.running:
assert self.node_name is not None
self.env.zenith_cli.pg_stop(self.node_name, tenant_id=self.tenant_id)
self.env.zenith_cli(['pg', 'stop', self.node_name, f'--tenantid={self.tenant_id}'])
self.running = False
return self
@@ -1381,7 +1172,8 @@ class Postgres(PgProtocol):
"""
assert self.node_name is not None
self.env.zenith_cli.pg_stop(self.node_name, self.tenant_id, destroy=True)
self.env.zenith_cli(
['pg', 'stop', '--destroy', self.node_name, f'--tenantid={self.tenant_id}'])
self.node_name = None
return self
@@ -1423,7 +1215,7 @@ class PostgresFactory:
def create_start(self,
node_name: str = "main",
branch: Optional[str] = None,
tenant_id: Optional[uuid.UUID] = None,
tenant_id: Optional[str] = None,
config_lines: Optional[List[str]] = None) -> Postgres:
pg = Postgres(
@@ -1443,7 +1235,7 @@ class PostgresFactory:
def create(self,
node_name: str = "main",
branch: Optional[str] = None,
tenant_id: Optional[uuid.UUID] = None,
tenant_id: Optional[str] = None,
config_lines: Optional[List[str]] = None) -> Postgres:
pg = Postgres(
@@ -1484,14 +1276,12 @@ class Safekeeper:
""" An object representing a running safekeeper daemon. """
env: ZenithEnv
port: SafekeeperPort
id: int
name: str # identifier for logging
auth_token: Optional[str] = None
running: bool = False
def start(self) -> 'Safekeeper':
assert self.running == False
self.env.zenith_cli.safekeeper_start(self.id)
self.running = True
self.env.zenith_cli(['safekeeper', 'start', self.name])
# wait for wal acceptor start by checking its status
started_at = time.time()
while True:
@@ -1509,14 +1299,16 @@ class Safekeeper:
return self
def stop(self, immediate=False) -> 'Safekeeper':
log.info('Stopping safekeeper {}'.format(self.id))
self.env.zenith_cli.safekeeper_stop(self.id, immediate)
self.running = False
cmd = ['safekeeper', 'stop']
if immediate:
cmd.extend(['-m', 'immediate'])
cmd.append(self.name)
log.info('Stopping safekeeper {}'.format(self.name))
self.env.zenith_cli(cmd)
return self
def append_logical_message(self,
tenant_id: uuid.UUID,
timeline_id: uuid.UUID,
def append_logical_message(self, tenant_id: str, timeline_id: str,
request: Dict[str, Any]) -> Dict[str, Any]:
"""
Send JSON_CTRL query to append LogicalMessage to WAL and modify
@@ -1526,7 +1318,7 @@ class Safekeeper:
# "replication=0" hacks psycopg not to send additional queries
# on startup, see https://github.com/psycopg/psycopg2/pull/482
connstr = f"host=localhost port={self.port.pg} replication=0 options='-c ztimelineid={timeline_id.hex} ztenantid={tenant_id.hex}'"
connstr = f"host=localhost port={self.port.pg} replication=0 options='-c ztimelineid={timeline_id} ztenantid={tenant_id}'"
with closing(psycopg2.connect(connstr)) as conn:
# server doesn't support transactions
@@ -1555,8 +1347,8 @@ class SafekeeperTimelineStatus:
class SafekeeperMetrics:
# These are metrics from Prometheus which uses float64 internally.
# As a consequence, values may differ from real original int64s.
flush_lsn_inexact: Dict[Tuple[str, str], int] = field(default_factory=dict)
commit_lsn_inexact: Dict[Tuple[str, str], int] = field(default_factory=dict)
flush_lsn_inexact: Dict[str, int] = field(default_factory=dict)
commit_lsn_inexact: Dict[str, int] = field(default_factory=dict)
class SafekeeperHttpClient(requests.Session):
@@ -1580,16 +1372,14 @@ class SafekeeperHttpClient(requests.Session):
all_metrics_text = request_result.text
metrics = SafekeeperMetrics()
for match in re.finditer(
r'^safekeeper_flush_lsn{tenant_id="([0-9a-f]+)",timeline_id="([0-9a-f]+)"} (\S+)$',
all_metrics_text,
re.MULTILINE):
metrics.flush_lsn_inexact[(match.group(1), match.group(2))] = int(match.group(3))
for match in re.finditer(
r'^safekeeper_commit_lsn{tenant_id="([0-9a-f]+)",timeline_id="([0-9a-f]+)"} (\S+)$',
all_metrics_text,
re.MULTILINE):
metrics.commit_lsn_inexact[(match.group(1), match.group(2))] = int(match.group(3))
for match in re.finditer(r'^safekeeper_flush_lsn{ztli="([0-9a-f]+)"} (\S+)$',
all_metrics_text,
re.MULTILINE):
metrics.flush_lsn_inexact[match.group(1)] = int(match.group(2))
for match in re.finditer(r'^safekeeper_commit_lsn{ztli="([0-9a-f]+)"} (\S+)$',
all_metrics_text,
re.MULTILINE):
metrics.commit_lsn_inexact[match.group(1)] = int(match.group(2))
return metrics
@@ -1698,7 +1488,7 @@ def check_restored_datadir_content(test_output_dir: str, env: ZenithEnv, pg: Pos
{psql_path} \
--no-psqlrc \
postgres://localhost:{env.pageserver.service_port.pg} \
-c 'basebackup {pg.tenant_id.hex} {timeline}' \
-c 'basebackup {pg.tenant_id} {timeline}' \
| tar -x -C {restored_dir_path}
"""

View File

@@ -4,6 +4,12 @@ from fixtures.log_helper import log
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
#
# Run bulk INSERT test.

View File

@@ -4,6 +4,8 @@ import pytest
from fixtures.zenith_fixtures import ZenithEnvBuilder
pytest_plugins = ("fixtures.benchmark_fixture")
# Run bulk tenant creation test.
#
# Collects metrics:
@@ -23,7 +25,7 @@ def test_bulk_tenant_create(
"""Measure tenant creation time (with and without wal acceptors)"""
if use_wal_acceptors == 'with_wa':
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init_start()
env = zenith_env_builder.init()
time_slices = []
@@ -31,10 +33,12 @@ def test_bulk_tenant_create(
start = timeit.default_timer()
tenant = env.create_tenant()
env.zenith_cli.create_branch(
env.zenith_cli([
"branch",
f"test_bulk_tenant_create_{tenants_count}_{i}_{use_wal_acceptors}",
"main",
tenant_id=tenant)
f"--tenantid={tenant}"
])
# FIXME: We used to start new safekeepers here. Did that make sense? Should we do it now?
#if use_wal_acceptors == 'with_wa':

Some files were not shown because too many files have changed in this diff Show More