mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-16 20:50:37 +00:00
Compare commits
31 Commits
refactor-l
...
asher/sk-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f281dc5953 | ||
|
|
48fb085ebd | ||
|
|
2bbd24edbf | ||
|
|
5e972ccdc4 | ||
|
|
95bf19b85a | ||
|
|
80d4afab0c | ||
|
|
0807522a64 | ||
|
|
8eebd5f039 | ||
|
|
8c07ef413d | ||
|
|
14df37c108 | ||
|
|
d4d0aa6ed6 | ||
|
|
a457256fef | ||
|
|
3a22e1335d | ||
|
|
93c77b0383 | ||
|
|
7920b39a27 | ||
|
|
23d5e2bdaa | ||
|
|
3526323bc4 | ||
|
|
af9425394f | ||
|
|
debd134b15 | ||
|
|
df42213dbb | ||
|
|
b6237474d2 | ||
|
|
8b710b9753 | ||
|
|
c187de1101 | ||
|
|
8712e1899e | ||
|
|
d7f1e30112 | ||
|
|
6a9d1030a6 | ||
|
|
8c6e607327 | ||
|
|
f436fb2dfb | ||
|
|
8932d14d50 | ||
|
|
efad64bc7f | ||
|
|
10dae79c6d |
10
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
10
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
## Describe your changes
|
||||
|
||||
## Issue ticket number and link
|
||||
|
||||
## Checklist before requesting a review
|
||||
- [ ] I have performed a self-review of my code.
|
||||
- [ ] If it is a core feature, I have added thorough tests.
|
||||
- [ ] Do we need to implement analytics? if so did you add the relevant metrics to the dashboard?
|
||||
- [ ] If this PR requires public announcement, mark it with /release-notes label and add several sentences in this section.
|
||||
|
||||
@@ -123,8 +123,8 @@ runs:
|
||||
exit 1
|
||||
fi
|
||||
if [[ "${{ inputs.run_in_parallel }}" == "true" ]]; then
|
||||
# -n8 uses eight processes to run tests via pytest-xdist
|
||||
EXTRA_PARAMS="-n8 $EXTRA_PARAMS"
|
||||
# -n4 uses four processes to run tests via pytest-xdist
|
||||
EXTRA_PARAMS="-n4 $EXTRA_PARAMS"
|
||||
|
||||
# --dist=loadgroup points tests marked with @pytest.mark.xdist_group
|
||||
# to the same worker to make @pytest.mark.order work with xdist
|
||||
|
||||
@@ -9,6 +9,7 @@ settings:
|
||||
authEndpoint: "http://console-staging.local/management/api/v2"
|
||||
domain: "*.eu-west-1.aws.neon.build"
|
||||
sentryEnvironment: "development"
|
||||
wssPort: 8443
|
||||
|
||||
# -- Additional labels for neon-proxy pods
|
||||
podLabels:
|
||||
@@ -23,6 +24,7 @@ exposedService:
|
||||
service.beta.kubernetes.io/aws-load-balancer-nlb-target-type: ip
|
||||
service.beta.kubernetes.io/aws-load-balancer-scheme: internet-facing
|
||||
external-dns.alpha.kubernetes.io/hostname: eu-west-1.aws.neon.build
|
||||
httpsPort: 443
|
||||
|
||||
#metrics:
|
||||
# enabled: true
|
||||
|
||||
@@ -9,6 +9,7 @@ settings:
|
||||
authEndpoint: "http://console-staging.local/management/api/v2"
|
||||
domain: "*.cloud.stage.neon.tech"
|
||||
sentryEnvironment: "development"
|
||||
wssPort: 8443
|
||||
|
||||
# -- Additional labels for neon-proxy pods
|
||||
podLabels:
|
||||
@@ -23,6 +24,7 @@ exposedService:
|
||||
service.beta.kubernetes.io/aws-load-balancer-nlb-target-type: ip
|
||||
service.beta.kubernetes.io/aws-load-balancer-scheme: internet-facing
|
||||
external-dns.alpha.kubernetes.io/hostname: neon-proxy-scram-legacy.beta.us-east-2.aws.neon.build
|
||||
httpsPort: 443
|
||||
|
||||
#metrics:
|
||||
# enabled: true
|
||||
|
||||
@@ -9,6 +9,7 @@ settings:
|
||||
authEndpoint: "http://console-staging.local/management/api/v2"
|
||||
domain: "*.us-east-2.aws.neon.build"
|
||||
sentryEnvironment: "development"
|
||||
wssPort: 8443
|
||||
|
||||
# -- Additional labels for neon-proxy pods
|
||||
podLabels:
|
||||
@@ -23,6 +24,7 @@ exposedService:
|
||||
service.beta.kubernetes.io/aws-load-balancer-nlb-target-type: ip
|
||||
service.beta.kubernetes.io/aws-load-balancer-scheme: internet-facing
|
||||
external-dns.alpha.kubernetes.io/hostname: us-east-2.aws.neon.build
|
||||
httpsPort: 443
|
||||
|
||||
#metrics:
|
||||
# enabled: true
|
||||
|
||||
@@ -9,6 +9,7 @@ settings:
|
||||
authEndpoint: "http://console-release.local/management/api/v2"
|
||||
domain: "*.ap-southeast-1.aws.neon.tech"
|
||||
sentryEnvironment: "production"
|
||||
wssPort: 8443
|
||||
|
||||
# -- Additional labels for neon-proxy pods
|
||||
podLabels:
|
||||
@@ -23,6 +24,7 @@ exposedService:
|
||||
service.beta.kubernetes.io/aws-load-balancer-nlb-target-type: ip
|
||||
service.beta.kubernetes.io/aws-load-balancer-scheme: internet-facing
|
||||
external-dns.alpha.kubernetes.io/hostname: ap-southeast-1.aws.neon.tech
|
||||
httpsPort: 443
|
||||
|
||||
#metrics:
|
||||
# enabled: true
|
||||
|
||||
@@ -9,6 +9,7 @@ settings:
|
||||
authEndpoint: "http://console-release.local/management/api/v2"
|
||||
domain: "*.eu-central-1.aws.neon.tech"
|
||||
sentryEnvironment: "production"
|
||||
wssPort: 8443
|
||||
|
||||
# -- Additional labels for neon-proxy pods
|
||||
podLabels:
|
||||
@@ -23,6 +24,7 @@ exposedService:
|
||||
service.beta.kubernetes.io/aws-load-balancer-nlb-target-type: ip
|
||||
service.beta.kubernetes.io/aws-load-balancer-scheme: internet-facing
|
||||
external-dns.alpha.kubernetes.io/hostname: eu-central-1.aws.neon.tech
|
||||
httpsPort: 443
|
||||
|
||||
#metrics:
|
||||
# enabled: true
|
||||
|
||||
@@ -9,6 +9,7 @@ settings:
|
||||
authEndpoint: "http://console-release.local/management/api/v2"
|
||||
domain: "*.us-east-2.aws.neon.tech"
|
||||
sentryEnvironment: "production"
|
||||
wssPort: 8443
|
||||
|
||||
# -- Additional labels for neon-proxy pods
|
||||
podLabels:
|
||||
@@ -23,6 +24,7 @@ exposedService:
|
||||
service.beta.kubernetes.io/aws-load-balancer-nlb-target-type: ip
|
||||
service.beta.kubernetes.io/aws-load-balancer-scheme: internet-facing
|
||||
external-dns.alpha.kubernetes.io/hostname: us-east-2.aws.neon.tech
|
||||
httpsPort: 443
|
||||
|
||||
#metrics:
|
||||
# enabled: true
|
||||
|
||||
@@ -9,6 +9,7 @@ settings:
|
||||
authEndpoint: "http://console-release.local/management/api/v2"
|
||||
domain: "*.us-west-2.aws.neon.tech"
|
||||
sentryEnvironment: "production"
|
||||
wssPort: 8443
|
||||
|
||||
# -- Additional labels for neon-proxy pods
|
||||
podLabels:
|
||||
@@ -23,6 +24,7 @@ exposedService:
|
||||
service.beta.kubernetes.io/aws-load-balancer-nlb-target-type: ip
|
||||
service.beta.kubernetes.io/aws-load-balancer-scheme: internet-facing
|
||||
external-dns.alpha.kubernetes.io/hostname: us-west-2.aws.neon.tech
|
||||
httpsPort: 443
|
||||
|
||||
#metrics:
|
||||
# enabled: true
|
||||
|
||||
@@ -3,6 +3,7 @@ settings:
|
||||
authEndpoint: "http://console-release.local/management/api/v2"
|
||||
domain: "*.cloud.neon.tech"
|
||||
sentryEnvironment: "production"
|
||||
wssPort: 8443
|
||||
|
||||
podLabels:
|
||||
zenith_service: proxy-scram
|
||||
@@ -16,6 +17,7 @@ exposedService:
|
||||
service.beta.kubernetes.io/aws-load-balancer-nlb-target-type: ip
|
||||
service.beta.kubernetes.io/aws-load-balancer-scheme: internet-facing
|
||||
external-dns.alpha.kubernetes.io/hostname: '*.cloud.neon.tech'
|
||||
httpsPort: 443
|
||||
|
||||
metrics:
|
||||
enabled: true
|
||||
|
||||
32
.github/workflows/build_and_test.yml
vendored
32
.github/workflows/build_and_test.yml
vendored
@@ -794,6 +794,8 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include: ${{fromJSON(needs.calculate-deploy-targets.outputs.matrix-include)}}
|
||||
environment:
|
||||
name: prod-old
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
@@ -839,7 +841,9 @@ jobs:
|
||||
shell: bash
|
||||
strategy:
|
||||
matrix:
|
||||
target_region: [ us-east-2 ]
|
||||
target_region: [ eu-west-1, us-east-2 ]
|
||||
environment:
|
||||
name: dev-${{ matrix.target_region }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
@@ -911,6 +915,8 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
target_region: [ us-east-2, us-west-2, eu-central-1, ap-southeast-1 ]
|
||||
environment:
|
||||
name: prod-${{ matrix.target_region }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
@@ -950,6 +956,8 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include: ${{fromJSON(needs.calculate-deploy-targets.outputs.matrix-include)}}
|
||||
environment:
|
||||
name: prod-old
|
||||
env:
|
||||
KUBECONFIG: .kubeconfig
|
||||
steps:
|
||||
@@ -975,8 +983,8 @@ jobs:
|
||||
- name: Re-deploy proxy
|
||||
run: |
|
||||
DOCKER_TAG=${{needs.tag.outputs.build-tag}}
|
||||
helm upgrade ${{ matrix.proxy_job }} neondatabase/neon-proxy --namespace neon-proxy --install -f .github/helm-values/${{ matrix.proxy_config }}.yaml --set image.tag=${DOCKER_TAG} --set settings.sentryUrl=${{ secrets.SENTRY_URL_PROXY }} --wait --timeout 15m0s
|
||||
helm upgrade ${{ matrix.proxy_job }}-scram neondatabase/neon-proxy --namespace neon-proxy --install -f .github/helm-values/${{ matrix.proxy_config }}-scram.yaml --set image.tag=${DOCKER_TAG} --set settings.sentryUrl=${{ secrets.SENTRY_URL_PROXY }} --wait --timeout 15m0s
|
||||
helm upgrade ${{ matrix.proxy_job }} neondatabase/neon-proxy --namespace neon-proxy --install --atomic -f .github/helm-values/${{ matrix.proxy_config }}.yaml --set image.tag=${DOCKER_TAG} --set settings.sentryUrl=${{ secrets.SENTRY_URL_PROXY }} --wait --timeout 15m0s
|
||||
helm upgrade ${{ matrix.proxy_job }}-scram neondatabase/neon-proxy --namespace neon-proxy --install --atomic -f .github/helm-values/${{ matrix.proxy_config }}-scram.yaml --set image.tag=${DOCKER_TAG} --set settings.sentryUrl=${{ secrets.SENTRY_URL_PROXY }} --wait --timeout 15m0s
|
||||
|
||||
deploy-storage-broker:
|
||||
name: deploy storage broker on old staging and old prod
|
||||
@@ -993,6 +1001,8 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include: ${{fromJSON(needs.calculate-deploy-targets.outputs.matrix-include)}}
|
||||
environment:
|
||||
name: prod-old
|
||||
env:
|
||||
KUBECONFIG: .kubeconfig
|
||||
steps:
|
||||
@@ -1041,6 +1051,8 @@ jobs:
|
||||
target_cluster: dev-eu-west-1-zeta
|
||||
deploy_link_proxy: false
|
||||
deploy_legacy_scram_proxy: false
|
||||
environment:
|
||||
name: dev-${{ matrix.target_region }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
@@ -1056,19 +1068,19 @@ jobs:
|
||||
- name: Re-deploy scram proxy
|
||||
run: |
|
||||
DOCKER_TAG=${{needs.tag.outputs.build-tag}}
|
||||
helm upgrade neon-proxy-scram neondatabase/neon-proxy --namespace neon-proxy --create-namespace --install -f .github/helm-values/${{ matrix.target_cluster }}.neon-proxy-scram.yaml --set image.tag=${DOCKER_TAG} --set settings.sentryUrl=${{ secrets.SENTRY_URL_PROXY }} --wait --timeout 15m0s
|
||||
helm upgrade neon-proxy-scram neondatabase/neon-proxy --namespace neon-proxy --create-namespace --install --atomic -f .github/helm-values/${{ matrix.target_cluster }}.neon-proxy-scram.yaml --set image.tag=${DOCKER_TAG} --set settings.sentryUrl=${{ secrets.SENTRY_URL_PROXY }} --wait --timeout 15m0s
|
||||
|
||||
- name: Re-deploy link proxy
|
||||
if: matrix.deploy_link_proxy
|
||||
run: |
|
||||
DOCKER_TAG=${{needs.tag.outputs.build-tag}}
|
||||
helm upgrade neon-proxy-link neondatabase/neon-proxy --namespace neon-proxy --create-namespace --install -f .github/helm-values/${{ matrix.target_cluster }}.neon-proxy-link.yaml --set image.tag=${DOCKER_TAG} --set settings.sentryUrl=${{ secrets.SENTRY_URL_PROXY }} --wait --timeout 15m0s
|
||||
helm upgrade neon-proxy-link neondatabase/neon-proxy --namespace neon-proxy --create-namespace --install --atomic -f .github/helm-values/${{ matrix.target_cluster }}.neon-proxy-link.yaml --set image.tag=${DOCKER_TAG} --set settings.sentryUrl=${{ secrets.SENTRY_URL_PROXY }} --wait --timeout 15m0s
|
||||
|
||||
- name: Re-deploy legacy scram proxy
|
||||
if: matrix.deploy_legacy_scram_proxy
|
||||
run: |
|
||||
DOCKER_TAG=${{needs.tag.outputs.build-tag}}
|
||||
helm upgrade neon-proxy-scram-legacy neondatabase/neon-proxy --namespace neon-proxy --create-namespace --install -f .github/helm-values/${{ matrix.target_cluster }}.neon-proxy-scram-legacy.yaml --set image.tag=${DOCKER_TAG} --set settings.sentryUrl=${{ secrets.SENTRY_URL_PROXY }} --wait --timeout 15m0s
|
||||
helm upgrade neon-proxy-scram-legacy neondatabase/neon-proxy --namespace neon-proxy --create-namespace --install --atomic -f .github/helm-values/${{ matrix.target_cluster }}.neon-proxy-scram-legacy.yaml --set image.tag=${DOCKER_TAG} --set settings.sentryUrl=${{ secrets.SENTRY_URL_PROXY }} --wait --timeout 15m0s
|
||||
|
||||
deploy-storage-broker-dev-new:
|
||||
runs-on: [ self-hosted, dev, x64 ]
|
||||
@@ -1088,6 +1100,8 @@ jobs:
|
||||
target_cluster: dev-us-east-2-beta
|
||||
- target_region: eu-west-1
|
||||
target_cluster: dev-eu-west-1-zeta
|
||||
environment:
|
||||
name: dev-${{ matrix.target_region }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
@@ -1126,6 +1140,8 @@ jobs:
|
||||
target_cluster: prod-eu-central-1-gamma
|
||||
- target_region: ap-southeast-1
|
||||
target_cluster: prod-ap-southeast-1-epsilon
|
||||
environment:
|
||||
name: prod-${{ matrix.target_region }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
@@ -1141,7 +1157,7 @@ jobs:
|
||||
- name: Re-deploy proxy
|
||||
run: |
|
||||
DOCKER_TAG=${{needs.tag.outputs.build-tag}}
|
||||
helm upgrade neon-proxy-scram neondatabase/neon-proxy --namespace neon-proxy --create-namespace --install -f .github/helm-values/${{ matrix.target_cluster }}.neon-proxy-scram.yaml --set image.tag=${DOCKER_TAG} --set settings.sentryUrl=${{ secrets.SENTRY_URL_PROXY }} --wait --timeout 15m0s
|
||||
helm upgrade neon-proxy-scram neondatabase/neon-proxy --namespace neon-proxy --create-namespace --install --atomic -f .github/helm-values/${{ matrix.target_cluster }}.neon-proxy-scram.yaml --set image.tag=${DOCKER_TAG} --set settings.sentryUrl=${{ secrets.SENTRY_URL_PROXY }} --wait --timeout 15m0s
|
||||
|
||||
deploy-storage-broker-prod-new:
|
||||
runs-on: prod
|
||||
@@ -1165,6 +1181,8 @@ jobs:
|
||||
target_cluster: prod-eu-central-1-gamma
|
||||
- target_region: ap-southeast-1
|
||||
target_cluster: prod-ap-southeast-1-epsilon
|
||||
environment:
|
||||
name: prod-${{ matrix.target_region }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
96
Cargo.lock
generated
96
Cargo.lock
generated
@@ -1700,6 +1700,19 @@ dependencies = [
|
||||
"tokio-io-timeout",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-tungstenite"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d62004bcd4f6f85d9e2aa4206f1466ee67031f5ededcb6c6e62d48f9306ad879"
|
||||
dependencies = [
|
||||
"hyper",
|
||||
"pin-project",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.53"
|
||||
@@ -2497,12 +2510,15 @@ name = "pq_proto"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
"rand",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"workspace_hack",
|
||||
]
|
||||
@@ -2657,6 +2673,7 @@ dependencies = [
|
||||
"hex",
|
||||
"hmac",
|
||||
"hyper",
|
||||
"hyper-tungstenite",
|
||||
"itertools",
|
||||
"md5",
|
||||
"metrics",
|
||||
@@ -2666,6 +2683,7 @@ dependencies = [
|
||||
"pq_proto",
|
||||
"rand",
|
||||
"rcgen",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"routerify",
|
||||
"rstest",
|
||||
@@ -2677,6 +2695,7 @@ dependencies = [
|
||||
"sha2",
|
||||
"socket2",
|
||||
"thiserror",
|
||||
"tls-listener",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
@@ -2686,6 +2705,7 @@ dependencies = [
|
||||
"url",
|
||||
"utils",
|
||||
"uuid",
|
||||
"webpki-roots",
|
||||
"workspace_hack",
|
||||
"x509-parser",
|
||||
]
|
||||
@@ -3056,6 +3076,7 @@ dependencies = [
|
||||
"const_format",
|
||||
"crc32c",
|
||||
"fs2",
|
||||
"futures",
|
||||
"git-version",
|
||||
"hex",
|
||||
"humantime",
|
||||
@@ -3064,6 +3085,7 @@ dependencies = [
|
||||
"nix",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"postgres",
|
||||
"postgres-protocol",
|
||||
"postgres_ffi",
|
||||
@@ -3323,6 +3345,17 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha-1"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.5"
|
||||
@@ -3687,10 +3720,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.21.1"
|
||||
name = "tls-listener"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0020c875007ad96677dcc890298f4b942882c5d4eb7cc8f439fc3bf813dc9c95"
|
||||
checksum = "c9d4ff21187d434ac7709bfc7441ca88f63681247e5ad99f0f08c8c91ddc103d"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"hyper",
|
||||
"pin-project-lite",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.24.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d9f76183f91ecfb55e1d7d5602bd1d979e38a3a522fe900241cf195624d67ae"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"bytes",
|
||||
@@ -3698,12 +3745,11 @@ dependencies = [
|
||||
"memchr",
|
||||
"mio",
|
||||
"num_cpus",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
"socket2",
|
||||
"tokio-macros",
|
||||
"winapi",
|
||||
"windows-sys 0.42.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3800,6 +3846,18 @@ dependencies = [
|
||||
"xattr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.17.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.4"
|
||||
@@ -4026,6 +4084,25 @@ version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.17.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0"
|
||||
dependencies = [
|
||||
"base64 0.13.1",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand",
|
||||
"sha-1",
|
||||
"thiserror",
|
||||
"url",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.16.0"
|
||||
@@ -4114,6 +4191,12 @@ version = "2.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9"
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utils"
|
||||
version = "0.1.0"
|
||||
@@ -4124,6 +4207,7 @@ dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"criterion",
|
||||
"futures",
|
||||
"git-version",
|
||||
"hex",
|
||||
"hex-literal",
|
||||
@@ -4132,6 +4216,7 @@ dependencies = [
|
||||
"metrics",
|
||||
"nix",
|
||||
"once_cell",
|
||||
"pin-utils",
|
||||
"pq_proto",
|
||||
"rand",
|
||||
"routerify",
|
||||
@@ -4149,6 +4234,7 @@ dependencies = [
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"workspace_hack",
|
||||
|
||||
11
README.md
11
README.md
@@ -118,11 +118,8 @@ Python (3.9 or higher), and install python3 packages using `./scripts/pysync` (r
|
||||
# Later that would be responsibility of a package install script
|
||||
> ./target/debug/neon_local init
|
||||
Starting pageserver at '127.0.0.1:64000' in '.neon'.
|
||||
pageserver started, pid: 2545906
|
||||
Successfully initialized timeline de200bd42b49cc1814412c7e592dd6e9
|
||||
Stopped pageserver 1 process with pid 2545906
|
||||
|
||||
# start pageserver and safekeeper
|
||||
# start pageserver, safekeeper, and broker for their intercommunication
|
||||
> ./target/debug/neon_local start
|
||||
Starting neon broker at 127.0.0.1:50051
|
||||
storage_broker started, pid: 2918372
|
||||
@@ -131,6 +128,12 @@ pageserver started, pid: 2918386
|
||||
Starting safekeeper at '127.0.0.1:5454' in '.neon/safekeepers/sk1'.
|
||||
safekeeper 1 started, pid: 2918437
|
||||
|
||||
# create initial tenant and use it as a default for every future neon_local invocation
|
||||
> ./target/debug/neon_local tenant create --set-default
|
||||
tenant 9ef87a5bf0d92544f6fafeeb3239695c successfully created on the pageserver
|
||||
Created an initial timeline 'de200bd42b49cc1814412c7e592dd6e9' at Lsn 0/16B5A50 for tenant: 9ef87a5bf0d92544f6fafeeb3239695c
|
||||
Setting tenant 9ef87a5bf0d92544f6fafeeb3239695c as a default one
|
||||
|
||||
# start postgres compute node
|
||||
> ./target/debug/neon_local pg start main
|
||||
Starting new postgres (v14) main on timeline de200bd42b49cc1814412c7e592dd6e9 ...
|
||||
|
||||
@@ -52,10 +52,16 @@ fn watch_compute_activity(compute: &ComputeNode) {
|
||||
let mut idle_backs: Vec<DateTime<Utc>> = vec![];
|
||||
|
||||
for b in backs.into_iter() {
|
||||
let state: String = b.get("state");
|
||||
let change: String = b.get("state_change");
|
||||
let state: String = match b.try_get("state") {
|
||||
Ok(state) => state,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
if state == "idle" {
|
||||
let change: String = match b.try_get("state_change") {
|
||||
Ok(state_change) => state_change,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let change = DateTime::parse_from_rfc3339(&change);
|
||||
match change {
|
||||
Ok(t) => idle_backs.push(t.with_timezone(&Utc)),
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::Result;
|
||||
use log::{info, log_enabled, warn, Level};
|
||||
@@ -197,22 +198,18 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
|
||||
/// Reassign all dependent objects and delete requested roles.
|
||||
pub fn handle_role_deletions(node: &ComputeNode, client: &mut Client) -> Result<()> {
|
||||
let spec = &node.spec;
|
||||
|
||||
// First, reassign all dependent objects to db owners.
|
||||
if let Some(ops) = &spec.delta_operations {
|
||||
if let Some(ops) = &node.spec.delta_operations {
|
||||
// First, reassign all dependent objects to db owners.
|
||||
info!("reassigning dependent objects of to-be-deleted roles");
|
||||
for op in ops {
|
||||
if op.action == "delete_role" {
|
||||
reassign_owned_objects(node, &op.name)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Second, proceed with role deletions.
|
||||
let mut xact = client.transaction()?;
|
||||
if let Some(ops) = &spec.delta_operations {
|
||||
// Second, proceed with role deletions.
|
||||
info!("processing role deletions");
|
||||
let mut xact = client.transaction()?;
|
||||
for op in ops {
|
||||
// We do not check either role exists or not,
|
||||
// Postgres will take care of it for us
|
||||
@@ -223,6 +220,7 @@ pub fn handle_role_deletions(node: &ComputeNode, client: &mut Client) -> Result<
|
||||
xact.execute(query.as_str(), &[])?;
|
||||
}
|
||||
}
|
||||
xact.commit()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -317,6 +315,7 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
// XXX: with a limited number of databases it is fine, but consider making it a HashMap
|
||||
let pg_db = existing_dbs.iter().find(|r| r.name == *name);
|
||||
|
||||
let start_time = Instant::now();
|
||||
if let Some(r) = pg_db {
|
||||
// XXX: db owner name is returned as quoted string from Postgres,
|
||||
// when quoting is needed.
|
||||
@@ -335,6 +334,8 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
info_print!(" -> update");
|
||||
|
||||
client.execute(query.as_str(), &[])?;
|
||||
let elapsed = start_time.elapsed().as_millis();
|
||||
info_print!(" ({} ms)", elapsed);
|
||||
}
|
||||
} else {
|
||||
let mut query: String = format!("CREATE DATABASE {} ", name.pg_quote());
|
||||
@@ -342,6 +343,9 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
|
||||
query.push_str(&db.to_pg_options());
|
||||
client.execute(query.as_str(), &[])?;
|
||||
|
||||
let elapsed = start_time.elapsed().as_millis();
|
||||
info_print!(" ({} ms)", elapsed);
|
||||
}
|
||||
|
||||
info_print!("\n");
|
||||
|
||||
@@ -136,22 +136,6 @@ where
|
||||
anyhow::bail!("{process_name} did not start in {RETRY_UNTIL_SECS} seconds");
|
||||
}
|
||||
|
||||
/// Send SIGTERM to child process
|
||||
pub fn send_stop_child_process(child: &std::process::Child) -> anyhow::Result<()> {
|
||||
let pid = child.id();
|
||||
match kill(
|
||||
nix::unistd::Pid::from_raw(pid.try_into().unwrap()),
|
||||
Signal::SIGTERM,
|
||||
) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(Errno::ESRCH) => {
|
||||
println!("child process with pid {pid} does not exist");
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => anyhow::bail!("Failed to send signal to child process with pid {pid}: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Stops the process, using the pid file given. Returns Ok also if the process is already not running.
|
||||
pub fn stop_process(immediate: bool, process_name: &str, pid_file: &Path) -> anyhow::Result<()> {
|
||||
let pid = match pid_file::read(pid_file)
|
||||
|
||||
@@ -263,7 +263,7 @@ fn get_tenant_id(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> anyhow::R
|
||||
} else if let Some(default_id) = env.default_tenant_id {
|
||||
Ok(default_id)
|
||||
} else {
|
||||
bail!("No tenant id. Use --tenant-id, or set 'default_tenant_id' in the config file");
|
||||
anyhow::bail!("No tenant id. Use --tenant-id, or set a default tenant");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,8 +284,6 @@ fn parse_timeline_id(sub_match: &ArgMatches) -> anyhow::Result<Option<TimelineId
|
||||
}
|
||||
|
||||
fn handle_init(init_match: &ArgMatches) -> anyhow::Result<LocalEnv> {
|
||||
let initial_timeline_id_arg = parse_timeline_id(init_match)?;
|
||||
|
||||
// Create config file
|
||||
let toml_file: String = if let Some(config_path) = init_match.get_one::<PathBuf>("config") {
|
||||
// load and parse the file
|
||||
@@ -309,30 +307,16 @@ fn handle_init(init_match: &ArgMatches) -> anyhow::Result<LocalEnv> {
|
||||
LocalEnv::parse_config(&toml_file).context("Failed to create neon configuration")?;
|
||||
env.init(pg_version)
|
||||
.context("Failed to initialize neon repository")?;
|
||||
let initial_tenant_id = env
|
||||
.default_tenant_id
|
||||
.expect("default_tenant_id should be generated by the `env.init()` call above");
|
||||
|
||||
// Initialize pageserver, create initial tenant and timeline.
|
||||
let pageserver = PageServerNode::from_env(&env);
|
||||
let initial_timeline_id = pageserver
|
||||
.initialize(
|
||||
Some(initial_tenant_id),
|
||||
initial_timeline_id_arg,
|
||||
&pageserver_config_overrides(init_match),
|
||||
pg_version,
|
||||
)
|
||||
pageserver
|
||||
.initialize(&pageserver_config_overrides(init_match))
|
||||
.unwrap_or_else(|e| {
|
||||
eprintln!("pageserver init failed: {e:?}");
|
||||
exit(1);
|
||||
});
|
||||
|
||||
env.register_branch_mapping(
|
||||
DEFAULT_BRANCH_NAME.to_owned(),
|
||||
initial_tenant_id,
|
||||
initial_timeline_id,
|
||||
)?;
|
||||
|
||||
Ok(env)
|
||||
}
|
||||
|
||||
@@ -388,6 +372,17 @@ fn handle_tenant(tenant_match: &ArgMatches, env: &mut local_env::LocalEnv) -> an
|
||||
println!(
|
||||
"Created an initial timeline '{new_timeline_id}' at Lsn {last_record_lsn} for tenant: {new_tenant_id}",
|
||||
);
|
||||
|
||||
if create_match.get_flag("set-default") {
|
||||
println!("Setting tenant {new_tenant_id} as a default one");
|
||||
env.default_tenant_id = Some(new_tenant_id);
|
||||
}
|
||||
}
|
||||
Some(("set-default", set_default_match)) => {
|
||||
let tenant_id =
|
||||
parse_tenant_id(set_default_match)?.context("No tenant id specified")?;
|
||||
println!("Setting tenant {tenant_id} as a default one");
|
||||
env.default_tenant_id = Some(tenant_id);
|
||||
}
|
||||
Some(("config", create_match)) => {
|
||||
let tenant_id = get_tenant_id(create_match, env)?;
|
||||
@@ -928,9 +923,8 @@ fn cli() -> Command {
|
||||
.version(GIT_VERSION)
|
||||
.subcommand(
|
||||
Command::new("init")
|
||||
.about("Initialize a new Neon repository")
|
||||
.about("Initialize a new Neon repository, preparing configs for services to start with")
|
||||
.arg(pageserver_config_args.clone())
|
||||
.arg(timeline_id_arg.clone().help("Use a specific timeline id when creating a tenant and its initial timeline"))
|
||||
.arg(
|
||||
Arg::new("config")
|
||||
.long("config")
|
||||
@@ -992,11 +986,14 @@ fn cli() -> Command {
|
||||
.arg(timeline_id_arg.clone().help("Use a specific timeline id when creating a tenant and its initial timeline"))
|
||||
.arg(Arg::new("config").short('c').num_args(1).action(ArgAction::Append).required(false))
|
||||
.arg(pg_version_arg.clone())
|
||||
.arg(Arg::new("set-default").long("set-default").action(ArgAction::SetTrue).required(false)
|
||||
.help("Use this tenant in future CLI commands where tenant_id is needed, but not specified"))
|
||||
)
|
||||
.subcommand(Command::new("set-default").arg(tenant_id_arg.clone().required(true))
|
||||
.about("Set a particular tenant as default in future CLI commands where tenant_id is needed, but not specified"))
|
||||
.subcommand(Command::new("config")
|
||||
.arg(tenant_id_arg.clone())
|
||||
.arg(Arg::new("config").short('c').num_args(1).action(ArgAction::Append).required(false))
|
||||
)
|
||||
.arg(Arg::new("config").short('c').num_args(1).action(ArgAction::Append).required(false)))
|
||||
)
|
||||
.subcommand(
|
||||
Command::new("pageserver")
|
||||
|
||||
@@ -14,7 +14,7 @@ use anyhow::{Context, Result};
|
||||
use utils::{
|
||||
id::{TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::AuthType,
|
||||
postgres_backend_async::AuthType,
|
||||
};
|
||||
|
||||
use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION};
|
||||
|
||||
@@ -19,7 +19,7 @@ use std::process::{Command, Stdio};
|
||||
use utils::{
|
||||
auth::{encode_from_key_file, Claims, Scope},
|
||||
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
|
||||
postgres_backend::AuthType,
|
||||
postgres_backend_async::AuthType,
|
||||
};
|
||||
|
||||
use crate::safekeeper::SafekeeperNode;
|
||||
@@ -296,11 +296,6 @@ impl LocalEnv {
|
||||
env.neon_distrib_dir = env::current_exe()?.parent().unwrap().to_owned();
|
||||
}
|
||||
|
||||
// If no initial tenant ID was given, generate it.
|
||||
if env.default_tenant_id.is_none() {
|
||||
env.default_tenant_id = Some(TenantId::generate());
|
||||
}
|
||||
|
||||
env.base_data_dir = base_path();
|
||||
|
||||
Ok(env)
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::path::PathBuf;
|
||||
use std::process::{Child, Command};
|
||||
use std::{io, result};
|
||||
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use anyhow::{bail, Context};
|
||||
use pageserver_api::models::{
|
||||
TenantConfigRequest, TenantCreateRequest, TenantInfo, TimelineCreateRequest, TimelineInfo,
|
||||
};
|
||||
@@ -130,83 +130,15 @@ impl PageServerNode {
|
||||
overrides
|
||||
}
|
||||
|
||||
/// Initializes a pageserver node by creating its config with the overrides provided,
|
||||
/// and creating an initial tenant and timeline afterwards.
|
||||
pub fn initialize(
|
||||
&self,
|
||||
create_tenant: Option<TenantId>,
|
||||
initial_timeline_id: Option<TimelineId>,
|
||||
config_overrides: &[&str],
|
||||
pg_version: u32,
|
||||
) -> anyhow::Result<TimelineId> {
|
||||
/// Initializes a pageserver node by creating its config with the overrides provided.
|
||||
pub fn initialize(&self, config_overrides: &[&str]) -> anyhow::Result<()> {
|
||||
// First, run `pageserver --init` and wait for it to write a config into FS and exit.
|
||||
self.pageserver_init(config_overrides).with_context(|| {
|
||||
format!(
|
||||
"Failed to run init for pageserver node {}",
|
||||
self.env.pageserver.id,
|
||||
)
|
||||
})?;
|
||||
|
||||
// Then, briefly start it fully to run HTTP commands on it,
|
||||
// to create initial tenant and timeline.
|
||||
// We disable the remote storage, since we stop pageserver right after the timeline creation,
|
||||
// hence most of the uploads will either aborted or not started: no point to start them at all.
|
||||
let disabled_remote_storage_override = "remote_storage={}";
|
||||
let mut pageserver_process = self
|
||||
.start_node(
|
||||
&[disabled_remote_storage_override],
|
||||
// Previous overrides will be taken from the config created before, don't overwrite them.
|
||||
false,
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to start a process for pageserver node {}",
|
||||
self.env.pageserver.id,
|
||||
)
|
||||
})?;
|
||||
|
||||
let init_result = self
|
||||
.try_init_timeline(create_tenant, initial_timeline_id, pg_version)
|
||||
.context("Failed to create initial tenant and timeline for pageserver");
|
||||
match &init_result {
|
||||
Ok(initial_timeline_id) => {
|
||||
println!("Successfully initialized timeline {initial_timeline_id}")
|
||||
}
|
||||
Err(e) => eprintln!("{e:#}"),
|
||||
}
|
||||
background_process::send_stop_child_process(&pageserver_process)?;
|
||||
|
||||
let exit_code = pageserver_process.wait()?;
|
||||
ensure!(
|
||||
exit_code.success(),
|
||||
format!(
|
||||
"pageserver init failed with exit code {:?}",
|
||||
exit_code.code()
|
||||
)
|
||||
);
|
||||
println!(
|
||||
"Stopped pageserver {} process with pid {}",
|
||||
self.env.pageserver.id,
|
||||
pageserver_process.id(),
|
||||
);
|
||||
init_result
|
||||
}
|
||||
|
||||
fn try_init_timeline(
|
||||
&self,
|
||||
new_tenant_id: Option<TenantId>,
|
||||
new_timeline_id: Option<TimelineId>,
|
||||
pg_version: u32,
|
||||
) -> anyhow::Result<TimelineId> {
|
||||
let initial_tenant_id = self.tenant_create(new_tenant_id, HashMap::new())?;
|
||||
let initial_timeline_info = self.timeline_create(
|
||||
initial_tenant_id,
|
||||
new_timeline_id,
|
||||
None,
|
||||
None,
|
||||
Some(pg_version),
|
||||
)?;
|
||||
Ok(initial_timeline_info.timeline_id)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn repo_path(&self) -> PathBuf {
|
||||
|
||||
@@ -7,11 +7,14 @@ license = "Apache-2.0"
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
bytes = "1.0.1"
|
||||
byteorder = "1.4.3"
|
||||
pin-project-lite = "0.2.7"
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||
rand = "0.8.3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tokio = { version = "1.17", features = ["macros"] }
|
||||
tokio-util = { version = "0.7.3" }
|
||||
tracing = "0.1"
|
||||
thiserror = "1.0"
|
||||
|
||||
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
|
||||
|
||||
62
libs/pq_proto/src/codec.rs
Normal file
62
libs/pq_proto/src/codec.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
//! Provides `PostgresCodec` defining how to serilize/deserialize Postgres
|
||||
//! messages to/from the wire, to be used with `tokio_util::codec::Framed`.
|
||||
use std::io;
|
||||
|
||||
use bytes::BytesMut;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
|
||||
// Defines how to serilize/deserialize Postgres messages to/from the wire, to be
|
||||
// used with `tokio_util::codec::Framed`.
|
||||
pub struct PostgresCodec {
|
||||
// Have we already decoded startup message? All further should start with
|
||||
// message type byte then.
|
||||
startup_read: bool,
|
||||
}
|
||||
|
||||
impl PostgresCodec {
|
||||
pub fn new() -> Self {
|
||||
PostgresCodec {
|
||||
startup_read: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error on postgres connection: either IO (physical transport error) or
|
||||
/// protocol violation.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ConnectionError {
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
#[error(transparent)]
|
||||
Protocol(#[from] ProtocolError),
|
||||
}
|
||||
|
||||
impl Encoder<&BeMessage<'_>> for PostgresCodec {
|
||||
type Error = ConnectionError;
|
||||
|
||||
fn encode(&mut self, item: &BeMessage, dst: &mut BytesMut) -> Result<(), ConnectionError> {
|
||||
BeMessage::write(dst, &item)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for PostgresCodec {
|
||||
type Item = FeMessage;
|
||||
type Error = ConnectionError;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
let msg = if !self.startup_read {
|
||||
let msg = FeStartupPacket::parse(src);
|
||||
if let Ok(Some(FeMessage::StartupPacket(FeStartupPacket::StartupMessage { .. }))) = msg
|
||||
{
|
||||
self.startup_read = true;
|
||||
}
|
||||
msg?
|
||||
} else {
|
||||
FeMessage::parse(src)?
|
||||
};
|
||||
Ok(msg)
|
||||
}
|
||||
}
|
||||
@@ -3,9 +3,11 @@
|
||||
//! on message formats.
|
||||
|
||||
// Tools for calling certain async methods in sync contexts.
|
||||
pub mod codec;
|
||||
pub mod sync;
|
||||
|
||||
use anyhow::{bail, ensure, Context, Result};
|
||||
use anyhow::{anyhow, bail, ensure, Context, Result};
|
||||
use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use postgres_protocol::PG_EPOCH;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -19,7 +21,7 @@ use std::{
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use sync::{AsyncishRead, SyncFuture};
|
||||
use tokio::io::AsyncReadExt;
|
||||
// use tokio::io::AsyncReadExt;
|
||||
use tracing::{trace, warn};
|
||||
|
||||
pub type Oid = u32;
|
||||
@@ -194,7 +196,108 @@ macro_rules! retry_read {
|
||||
};
|
||||
}
|
||||
|
||||
/// An error occured while parsing or serializing raw stream into Postgres
|
||||
/// messages.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ProtocolError {
|
||||
/// IO error during writing to or reading from the connection socket.
|
||||
/// removeme
|
||||
#[error("Socket IO error: {0}")]
|
||||
Socket(std::io::Error),
|
||||
/// Invalid packet was received from the client (e.g. unexpected message
|
||||
/// type or broken len).
|
||||
#[error("Protocol error: {0}")]
|
||||
Protocol(String),
|
||||
/// Failed to parse or, (unlikely), serialize a protocol message.
|
||||
#[error("Message parse error: {0}")]
|
||||
MessageParse(anyhow::Error),
|
||||
}
|
||||
|
||||
// Allows to return anyhow error from msg parsing routines, meaning less typing.
|
||||
impl From<anyhow::Error> for ProtocolError {
|
||||
fn from(e: anyhow::Error) -> Self {
|
||||
Self::MessageParse(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProtocolError {
|
||||
pub fn into_io_error(self) -> io::Error {
|
||||
match self {
|
||||
ProtocolError::Socket(io) => io,
|
||||
other => io::Error::new(io::ErrorKind::Other, other.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FeMessage {
|
||||
/// Read and parse one message from the `buf` input buffer. If there is at
|
||||
/// least one valid message, returns it, advancing `buf`; redundant copies
|
||||
/// are avoided, as thanks to `bytes` crate ptrs in parsed message point
|
||||
/// directly into the `buf` (processed data is garbage collected after
|
||||
/// parsed message is dropped).
|
||||
///
|
||||
/// Returns None if `buf` doesn't contain enough data for a single message.
|
||||
/// For efficiency, tries to reserve large enough space in `buf` for the
|
||||
/// next message in this case.
|
||||
///
|
||||
/// Returns Error if message is malformed, the only possible ErrorKind is
|
||||
/// InvalidInput.
|
||||
//
|
||||
// Inspired by rust-postgres Message::parse.
|
||||
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
|
||||
// Every message contains message type byte and 4 bytes len; can't do
|
||||
// much without them.
|
||||
if buf.len() < 5 {
|
||||
let to_read = 5 - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// We shouldn't advance `buf` as probably full message is not there yet,
|
||||
// so can't directly use Bytes::get_u32 etc.
|
||||
let tag = buf[0];
|
||||
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
|
||||
if len < 4 {
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"invalid message length {}",
|
||||
len
|
||||
)));
|
||||
}
|
||||
|
||||
// lenth field includes itself, but not message type.
|
||||
let total_len = len as usize + 1;
|
||||
if buf.len() < total_len {
|
||||
// Don't have full message yet.
|
||||
let to_read = total_len - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// got the message, advance buffer
|
||||
let mut msg = buf.split_to(total_len).freeze();
|
||||
msg.advance(5); // consume message type and len
|
||||
|
||||
match tag {
|
||||
b'Q' => Ok(Some(FeMessage::Query(msg))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
|
||||
b'S' => Ok(Some(FeMessage::Sync)),
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(msg))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
b'f' => Ok(Some(FeMessage::CopyFail)),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
|
||||
tag => {
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"unknown message tag: {tag},'{msg:?}'"
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read one message from the stream.
|
||||
/// This function returns `Ok(None)` in case of EOF.
|
||||
/// One way to handle this properly:
|
||||
@@ -216,58 +319,8 @@ impl FeMessage {
|
||||
/// }
|
||||
/// ```
|
||||
#[inline(never)]
|
||||
pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<Option<FeMessage>> {
|
||||
Self::read_fut(&mut AsyncishRead(stream)).wait()
|
||||
}
|
||||
|
||||
/// Read one message from the stream.
|
||||
/// See documentation for `Self::read`.
|
||||
pub fn read_fut<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<Option<FeMessage>>> + '_>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
|
||||
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
|
||||
// AsyncReadExt methods of the stream.
|
||||
SyncFuture::new(async move {
|
||||
// Each libpq message begins with a message type byte, followed by message length
|
||||
// If the client closes the connection, return None. But if the client closes the
|
||||
// connection in the middle of a message, we will return an error.
|
||||
let tag = match retry_read!(stream.read_u8().await) {
|
||||
Ok(b) => b,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
// The message length includes itself, so it better be at least 4.
|
||||
let len = retry_read!(stream.read_u32().await)?
|
||||
.checked_sub(4)
|
||||
.context("invalid message length")?;
|
||||
|
||||
let body = {
|
||||
let mut buffer = vec![0u8; len as usize];
|
||||
stream.read_exact(&mut buffer).await?;
|
||||
Bytes::from(buffer)
|
||||
};
|
||||
|
||||
match tag {
|
||||
b'Q' => Ok(Some(FeMessage::Query(body))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
|
||||
b'S' => Ok(Some(FeMessage::Sync)),
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(body))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
b'f' => Ok(Some(FeMessage::CopyFail)),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
|
||||
tag => bail!("unknown message tag: {},'{:?}'", tag, body),
|
||||
}
|
||||
})
|
||||
pub fn read(_stream: &mut (impl io::Read + Unpin)) -> Result<Option<FeMessage>, ProtocolError> {
|
||||
Ok(None) // removeme
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,19 +328,124 @@ impl FeStartupPacket {
|
||||
/// Read startup message from the stream.
|
||||
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
||||
// since such a change will cause user-supplied &mut references to be consumed
|
||||
pub fn read(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<Option<FeMessage>> {
|
||||
pub fn read(stream: &mut (impl io::Read + Unpin)) -> Result<Option<FeMessage>, ProtocolError> {
|
||||
Self::read_fut(&mut AsyncishRead(stream)).wait()
|
||||
}
|
||||
|
||||
/// Read and parse startup message from the `buf` input buffer. It is
|
||||
/// different from [`FeMessage::parse`] because startup messages don't have
|
||||
/// message type byte; otherwise, its comments apply.
|
||||
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
|
||||
const CANCEL_REQUEST_CODE: u32 = 5678;
|
||||
const NEGOTIATE_SSL_CODE: u32 = 5679;
|
||||
const NEGOTIATE_GSS_CODE: u32 = 5680;
|
||||
|
||||
if buf.len() < 4 {
|
||||
let to_read = 5 - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// We shouldn't advance `buf` as probably full message is not there yet,
|
||||
// so can't directly use Bytes::get_u32 etc.
|
||||
let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
|
||||
if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"invalid startup packet message length {}",
|
||||
len
|
||||
)));
|
||||
}
|
||||
|
||||
if buf.len() < len {
|
||||
// Don't have full message yet.
|
||||
let to_read = len - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// got the message, advance buffer
|
||||
let mut msg = buf.split_to(len).freeze();
|
||||
msg.advance(4); // consume len
|
||||
|
||||
let request_code = msg.get_u32();
|
||||
let req_hi = request_code >> 16;
|
||||
let req_lo = request_code & ((1 << 16) - 1);
|
||||
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
|
||||
let message = match (req_hi, req_lo) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
if msg.remaining() < 8 {
|
||||
return Err(ProtocolError::MessageParse(anyhow!(
|
||||
"CancelRequest message is malformed, backend PID / secret key missing"
|
||||
)));
|
||||
}
|
||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||
backend_pid: msg.get_i32(),
|
||||
cancel_key: msg.get_i32(),
|
||||
})
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
||||
// Requested upgrade to SSL (aka TLS)
|
||||
FeStartupPacket::SslRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
|
||||
// Requested upgrade to GSSAPI
|
||||
FeStartupPacket::GssEncRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"Unrecognized request code {unrecognized_code}"
|
||||
)));
|
||||
}
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
(major_version, minor_version) => {
|
||||
// StartupMessage
|
||||
|
||||
// Parse pairs of null-terminated strings (key, value).
|
||||
// See `postgres: ProcessStartupPacket, build_startup_packet`.
|
||||
let mut tokens = str::from_utf8(&msg)
|
||||
.context("StartupMessage params: invalid utf-8")?
|
||||
.strip_suffix('\0') // drop packet's own null
|
||||
.ok_or_else(|| {
|
||||
ProtocolError::Protocol(
|
||||
"StartupMessage params: missing null terminator".to_string(),
|
||||
)
|
||||
})?
|
||||
.split_terminator('\0');
|
||||
|
||||
let mut params = HashMap::new();
|
||||
while let Some(name) = tokens.next() {
|
||||
let value = tokens.next().ok_or_else(|| {
|
||||
ProtocolError::Protocol(
|
||||
"StartupMessage params: key without value".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
params.insert(name.to_owned(), value.to_owned());
|
||||
}
|
||||
|
||||
FeStartupPacket::StartupMessage {
|
||||
major_version,
|
||||
minor_version,
|
||||
params: StartupMessageParams { params },
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Some(FeMessage::StartupPacket(message)))
|
||||
}
|
||||
|
||||
/// Read startup message from the stream.
|
||||
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
||||
// since such a change will cause user-supplied &mut references to be consumed
|
||||
pub fn read_fut<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<Option<FeMessage>>> + '_>
|
||||
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ProtocolError>> + '_>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
|
||||
const CANCEL_REQUEST_CODE: u32 = 5678;
|
||||
@@ -302,31 +460,43 @@ impl FeStartupPacket {
|
||||
let len = match retry_read!(stream.read_u32().await) {
|
||||
Ok(len) => len as usize,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(e.into()),
|
||||
Err(e) => return Err(ProtocolError::Socket(e)),
|
||||
};
|
||||
|
||||
#[allow(clippy::manual_range_contains)]
|
||||
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||
bail!("invalid message length");
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"invalid message length {len}"
|
||||
)));
|
||||
}
|
||||
|
||||
let request_code = retry_read!(stream.read_u32().await)?;
|
||||
let request_code =
|
||||
retry_read!(stream.read_u32().await).map_err(ProtocolError::Socket)?;
|
||||
|
||||
// the rest of startup packet are params
|
||||
let params_len = len - 8;
|
||||
let mut params_bytes = vec![0u8; params_len];
|
||||
stream.read_exact(params_bytes.as_mut()).await?;
|
||||
stream
|
||||
.read_exact(params_bytes.as_mut())
|
||||
.await
|
||||
.map_err(ProtocolError::Socket)?;
|
||||
|
||||
// Parse params depending on request code
|
||||
let req_hi = request_code >> 16;
|
||||
let req_lo = request_code & ((1 << 16) - 1);
|
||||
let message = match (req_hi, req_lo) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
ensure!(params_len == 8, "expected 8 bytes for CancelRequest params");
|
||||
if params_len != 8 {
|
||||
return Err(ProtocolError::Protocol(
|
||||
"expected 8 bytes for CancelRequest params".to_string(),
|
||||
));
|
||||
}
|
||||
let mut cursor = Cursor::new(params_bytes);
|
||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||
backend_pid: cursor.read_i32().await?,
|
||||
cancel_key: cursor.read_i32().await?,
|
||||
backend_pid: 2,
|
||||
cancel_key: 2,
|
||||
// backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
// cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
})
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
||||
@@ -338,7 +508,9 @@ impl FeStartupPacket {
|
||||
FeStartupPacket::GssEncRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
bail!("Unrecognized request code {}", unrecognized_code)
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"Unrecognized request code {unrecognized_code}"
|
||||
)));
|
||||
}
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
(major_version, minor_version) => {
|
||||
@@ -346,15 +518,21 @@ impl FeStartupPacket {
|
||||
// See `postgres: ProcessStartupPacket, build_startup_packet`.
|
||||
let mut tokens = str::from_utf8(¶ms_bytes)
|
||||
.context("StartupMessage params: invalid utf-8")?
|
||||
.strip_suffix('\0') // drop packet's own null terminator
|
||||
.context("StartupMessage params: missing null terminator")?
|
||||
.strip_suffix('\0') // drop packet's own null
|
||||
.ok_or_else(|| {
|
||||
ProtocolError::Protocol(
|
||||
"StartupMessage params: missing null terminator".to_string(),
|
||||
)
|
||||
})?
|
||||
.split_terminator('\0');
|
||||
|
||||
let mut params = HashMap::new();
|
||||
while let Some(name) = tokens.next() {
|
||||
let value = tokens
|
||||
.next()
|
||||
.context("StartupMessage params: key without value")?;
|
||||
let value = tokens.next().ok_or_else(|| {
|
||||
ProtocolError::Protocol(
|
||||
"StartupMessage params: key without value".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
params.insert(name.to_owned(), value.to_owned());
|
||||
}
|
||||
@@ -381,6 +559,9 @@ impl FeParseMessage {
|
||||
|
||||
let _pstmt_name = read_cstr(&mut buf)?;
|
||||
let query_string = read_cstr(&mut buf)?;
|
||||
if buf.remaining() < 2 {
|
||||
bail!("Parse message is malformed, nparams missing");
|
||||
}
|
||||
let nparams = buf.get_i16();
|
||||
|
||||
ensure!(nparams == 0, "query params not implemented");
|
||||
@@ -407,6 +588,9 @@ impl FeDescribeMessage {
|
||||
impl FeExecuteMessage {
|
||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
let portal_name = read_cstr(&mut buf)?;
|
||||
if buf.remaining() < 4 {
|
||||
bail!("FeExecuteMessage message is malformed, maxrows missing");
|
||||
}
|
||||
let maxrows = buf.get_i32();
|
||||
|
||||
ensure!(portal_name.is_empty(), "named portals not implemented");
|
||||
@@ -458,7 +642,7 @@ pub enum BeMessage<'a> {
|
||||
CloseComplete,
|
||||
// None means column is NULL
|
||||
DataRow(&'a [Option<&'a [u8]>]),
|
||||
ErrorResponse(&'a str),
|
||||
ErrorResponse(&'a str, Option<&'a [u8; 5]>),
|
||||
/// Single byte - used in response to SSLRequest/GSSENCRequest.
|
||||
EncryptionResponse(bool),
|
||||
NoData,
|
||||
@@ -488,6 +672,11 @@ impl<'a> BeMessage<'a> {
|
||||
value: b"UTF8",
|
||||
};
|
||||
|
||||
pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
|
||||
name: b"integer_datetimes",
|
||||
value: b"on",
|
||||
};
|
||||
|
||||
/// Build a [`BeMessage::ParameterStatus`] holding the server version.
|
||||
pub fn server_version(version: &'a str) -> Self {
|
||||
Self::ParameterStatus {
|
||||
@@ -606,13 +795,12 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
|
||||
}
|
||||
|
||||
/// Safe write of s into buf as cstring (String in the protocol).
|
||||
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), io::Error> {
|
||||
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
|
||||
let bytes = s.as_ref();
|
||||
if bytes.contains(&0) {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"string contains embedded null",
|
||||
));
|
||||
return Err(ProtocolError::MessageParse(anyhow!(
|
||||
"string contains embedded null"
|
||||
)));
|
||||
}
|
||||
buf.put_slice(bytes);
|
||||
buf.put_u8(0);
|
||||
@@ -621,20 +809,20 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), io::Error>
|
||||
|
||||
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
|
||||
let pos = buf.iter().position(|x| *x == 0);
|
||||
let result = buf.split_to(pos.context("missing terminator")?);
|
||||
let result = buf.split_to(pos.context("missing cstring terminator")?);
|
||||
buf.advance(1); // drop the null terminator
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
const SQLSTATE_INTERNAL_ERROR: &str = "XX000\0";
|
||||
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
|
||||
|
||||
impl<'a> BeMessage<'a> {
|
||||
/// Write message to the given buf.
|
||||
// Unlike the reading side, we use BytesMut
|
||||
// here as msg len precedes its body and it is handy to write it down first
|
||||
// and then fill the length. With Write we would have to either calc it
|
||||
// manually or have one more buffer.
|
||||
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
|
||||
/// Serialize `message` to the given `buf`.
|
||||
/// Apart from smart memory managemet, BytesMut is good here as msg len
|
||||
/// precedes its body and it is handy to write it down first and then fill
|
||||
/// the length. With Write we would have to either calc it manually or have
|
||||
/// one more buffer.
|
||||
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
|
||||
match message {
|
||||
BeMessage::AuthenticationOk => {
|
||||
buf.put_u8(b'R');
|
||||
@@ -660,7 +848,7 @@ impl<'a> BeMessage<'a> {
|
||||
|
||||
BeMessage::AuthenticationSasl(msg) => {
|
||||
buf.put_u8(b'R');
|
||||
write_body(buf, |buf| {
|
||||
write_body(buf, |buf| -> Result<(), ProtocolError> {
|
||||
use BeAuthenticationSaslMessage::*;
|
||||
match msg {
|
||||
Methods(methods) => {
|
||||
@@ -679,7 +867,7 @@ impl<'a> BeMessage<'a> {
|
||||
buf.put_slice(extra);
|
||||
}
|
||||
}
|
||||
Ok::<_, io::Error>(())
|
||||
Ok(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -767,24 +955,23 @@ impl<'a> BeMessage<'a> {
|
||||
// First byte of each field represents type of this field. Set just enough fields
|
||||
// to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
|
||||
// message text.
|
||||
BeMessage::ErrorResponse(error_msg) => {
|
||||
// For all the errors set Severity to Error and error code to
|
||||
// 'internal error'.
|
||||
|
||||
BeMessage::ErrorResponse(error_msg, pg_error_code) => {
|
||||
// 'E' signalizes ErrorResponse messages
|
||||
buf.put_u8(b'E');
|
||||
write_body(buf, |buf| {
|
||||
write_body(buf, |buf| -> Result<(), ProtocolError> {
|
||||
buf.put_u8(b'S'); // severity
|
||||
buf.put_slice(b"ERROR\0");
|
||||
|
||||
buf.put_u8(b'C'); // SQLSTATE error code
|
||||
buf.put_slice(SQLSTATE_INTERNAL_ERROR.as_bytes());
|
||||
buf.put_slice(&terminate_code(
|
||||
pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
|
||||
));
|
||||
|
||||
buf.put_u8(b'M'); // the message
|
||||
write_cstr(error_msg, buf)?;
|
||||
|
||||
buf.put_u8(0); // terminator
|
||||
Ok::<_, io::Error>(())
|
||||
Ok(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -796,18 +983,18 @@ impl<'a> BeMessage<'a> {
|
||||
|
||||
// 'N' signalizes NoticeResponse messages
|
||||
buf.put_u8(b'N');
|
||||
write_body(buf, |buf| {
|
||||
write_body(buf, |buf| -> Result<(), ProtocolError> {
|
||||
buf.put_u8(b'S'); // severity
|
||||
buf.put_slice(b"NOTICE\0");
|
||||
|
||||
buf.put_u8(b'C'); // SQLSTATE error code
|
||||
buf.put_slice(SQLSTATE_INTERNAL_ERROR.as_bytes());
|
||||
buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
|
||||
|
||||
buf.put_u8(b'M'); // the message
|
||||
write_cstr(error_msg.as_bytes(), buf)?;
|
||||
|
||||
buf.put_u8(0); // terminator
|
||||
Ok::<_, io::Error>(())
|
||||
Ok(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -851,7 +1038,7 @@ impl<'a> BeMessage<'a> {
|
||||
|
||||
BeMessage::RowDescription(rows) => {
|
||||
buf.put_u8(b'T');
|
||||
write_body(buf, |buf| {
|
||||
write_body(buf, |buf| -> Result<(), ProtocolError> {
|
||||
buf.put_i16(rows.len() as i16); // # of fields
|
||||
for row in rows.iter() {
|
||||
write_cstr(row.name, buf)?;
|
||||
@@ -862,7 +1049,7 @@ impl<'a> BeMessage<'a> {
|
||||
buf.put_i32(-1); /* typmod */
|
||||
buf.put_i16(0); /* format code */
|
||||
}
|
||||
Ok::<_, io::Error>(())
|
||||
Ok(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -1089,3 +1276,12 @@ mod tests {
|
||||
let _ = FeStartupPacket::read_fut(stream).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
|
||||
let mut terminated = [0; 6];
|
||||
for (i, &elem) in code.iter().enumerate() {
|
||||
terminated[i] = elem;
|
||||
}
|
||||
|
||||
terminated
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
|
||||
}
|
||||
|
||||
pub struct Download {
|
||||
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send>>,
|
||||
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send + Sync>>,
|
||||
/// Extra key-value data, associated with the current remote file.
|
||||
pub metadata: Option<StorageMetadata>,
|
||||
}
|
||||
|
||||
@@ -10,13 +10,16 @@ async-trait = "0.1"
|
||||
anyhow = "1.0"
|
||||
bincode = "1.3"
|
||||
bytes = "1.0.1"
|
||||
futures = "0.3"
|
||||
hyper = { version = "0.14.7", features = ["full"] }
|
||||
pin-utils = "0.1"
|
||||
routerify = "3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
thiserror = "1.0"
|
||||
tokio = { version = "1.17", features = ["macros"]}
|
||||
tokio-rustls = "0.23"
|
||||
tokio-util = { version = "0.7.3" }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||
nix = "0.25"
|
||||
|
||||
@@ -13,7 +13,7 @@ pub mod simple_rcu;
|
||||
pub mod vec_map;
|
||||
|
||||
pub mod bin_ser;
|
||||
pub mod postgres_backend;
|
||||
// pub mod postgres_backend;
|
||||
pub mod postgres_backend_async;
|
||||
|
||||
// helper functions for creating and fsyncing
|
||||
@@ -52,6 +52,8 @@ pub mod signals;
|
||||
|
||||
pub mod fs_ext;
|
||||
|
||||
pub mod send_rc;
|
||||
|
||||
/// use with fail::cfg("$name", "return(2000)")
|
||||
#[macro_export]
|
||||
macro_rules! failpoint_sleep_millis_async {
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
|
||||
use crate::postgres_backend_async::{log_query_error, short_error, QueryError};
|
||||
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
|
||||
use anyhow::{bail, ensure, Context, Result};
|
||||
use anyhow::Context;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -21,20 +22,32 @@ pub trait Handler {
|
||||
/// postgres_backend will issue ReadyForQuery after calling this (this
|
||||
/// might be not what we want after CopyData streaming, but currently we don't
|
||||
/// care).
|
||||
fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: &str) -> Result<()>;
|
||||
fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError>;
|
||||
|
||||
/// Called on startup packet receival, allows to process params.
|
||||
///
|
||||
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
|
||||
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
|
||||
/// to override whole init logic in implementations.
|
||||
fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupPacket) -> Result<()> {
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_sm: &FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check auth jwt
|
||||
fn check_auth_jwt(&mut self, _pgb: &mut PostgresBackend, _jwt_response: &[u8]) -> Result<()> {
|
||||
bail!("JWT auth failed")
|
||||
fn check_auth_jwt(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_jwt_response: &[u8],
|
||||
) -> Result<(), QueryError> {
|
||||
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
|
||||
}
|
||||
|
||||
fn is_shutdown_requested(&self) -> bool {
|
||||
@@ -66,7 +79,7 @@ impl FromStr for AuthType {
|
||||
match s {
|
||||
"Trust" => Ok(Self::Trust),
|
||||
"NeonJWT" => Ok(Self::NeonJWT),
|
||||
_ => bail!("invalid value \"{s}\" for auth type"),
|
||||
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -154,7 +167,7 @@ pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool {
|
||||
}
|
||||
|
||||
// Cast a byte slice to a string slice, dropping null terminator if there's one.
|
||||
fn cstr_to_str(bytes: &[u8]) -> Result<&str> {
|
||||
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
|
||||
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
|
||||
std::str::from_utf8(without_null).map_err(|e| e.into())
|
||||
}
|
||||
@@ -188,10 +201,10 @@ impl PostgresBackend {
|
||||
}
|
||||
|
||||
/// Get direct reference (into the Option) to the read stream.
|
||||
fn get_stream_in(&mut self) -> Result<&mut BidiStream> {
|
||||
fn get_stream_in(&mut self) -> anyhow::Result<&mut BidiStream> {
|
||||
match &mut self.stream {
|
||||
Some(Stream::Bidirectional(stream)) => Ok(stream),
|
||||
_ => bail!("reader taken"),
|
||||
_ => anyhow::bail!("reader taken"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,7 +228,7 @@ impl PostgresBackend {
|
||||
}
|
||||
|
||||
/// Read full message or return None if connection is closed.
|
||||
pub fn read_message(&mut self) -> Result<Option<FeMessage>> {
|
||||
pub fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
|
||||
let (state, stream) = (self.state, self.get_stream_in()?);
|
||||
|
||||
use ProtoState::*;
|
||||
@@ -223,6 +236,7 @@ impl PostgresBackend {
|
||||
Initialization | Encrypted => FeStartupPacket::read(stream),
|
||||
Authentication | Established => FeMessage::read(stream),
|
||||
}
|
||||
.map_err(QueryError::from)
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer.
|
||||
@@ -246,7 +260,7 @@ impl PostgresBackend {
|
||||
}
|
||||
|
||||
// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
pub fn run(mut self, handler: &mut impl Handler) -> Result<()> {
|
||||
pub fn run(mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
|
||||
let ret = self.run_message_loop(handler);
|
||||
if let Some(stream) = self.stream.as_mut() {
|
||||
let _ = stream.shutdown(Shutdown::Both);
|
||||
@@ -254,7 +268,7 @@ impl PostgresBackend {
|
||||
ret
|
||||
}
|
||||
|
||||
fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<()> {
|
||||
fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
|
||||
trace!("postgres backend to {:?} started", self.peer_addr);
|
||||
|
||||
let mut unnamed_query_string = Bytes::new();
|
||||
@@ -263,7 +277,7 @@ impl PostgresBackend {
|
||||
match self.read_message() {
|
||||
Ok(message) => {
|
||||
if let Some(msg) = message {
|
||||
trace!("got message {:?}", msg);
|
||||
trace!("got message {msg:?}");
|
||||
|
||||
match self.process_message(handler, msg, &mut unnamed_query_string)? {
|
||||
ProcessMsgResult::Continue => continue,
|
||||
@@ -274,10 +288,12 @@ impl PostgresBackend {
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// If it is a timeout error, continue the loop
|
||||
if !is_socket_read_timed_out(&e) {
|
||||
return Err(e);
|
||||
if let QueryError::Other(e) = &e {
|
||||
if is_socket_read_timed_out(e) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -295,7 +311,7 @@ impl PostgresBackend {
|
||||
}
|
||||
stream => {
|
||||
self.stream = stream;
|
||||
bail!("can't start TLs without bidi stream");
|
||||
anyhow::bail!("can't start TLs without bidi stream");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -305,17 +321,16 @@ impl PostgresBackend {
|
||||
handler: &mut impl Handler,
|
||||
msg: FeMessage,
|
||||
unnamed_query_string: &mut Bytes,
|
||||
) -> Result<ProcessMsgResult> {
|
||||
) -> Result<ProcessMsgResult, QueryError> {
|
||||
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
|
||||
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
|
||||
if self.state < ProtoState::Established {
|
||||
ensure!(
|
||||
matches!(
|
||||
msg,
|
||||
FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_)
|
||||
),
|
||||
"protocol violation"
|
||||
);
|
||||
if self.state < ProtoState::Established
|
||||
&& !matches!(
|
||||
msg,
|
||||
FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_)
|
||||
)
|
||||
{
|
||||
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
|
||||
}
|
||||
|
||||
let have_tls = self.tls_config.is_some();
|
||||
@@ -339,8 +354,13 @@ impl PostgresBackend {
|
||||
}
|
||||
FeStartupPacket::StartupMessage { .. } => {
|
||||
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
|
||||
self.write_message(&BeMessage::ErrorResponse("must connect with TLS"))?;
|
||||
bail!("client did not connect with TLS");
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
"must connect with TLS",
|
||||
None,
|
||||
))?;
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"client did not connect with TLS"
|
||||
)));
|
||||
}
|
||||
|
||||
// NB: startup() may change self.auth_type -- we are using that in proxy code
|
||||
@@ -379,8 +399,11 @@ impl PostgresBackend {
|
||||
let (_, jwt_response) = m.split_last().context("protocol violation")?;
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?;
|
||||
bail!("auth failed: {}", e);
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -394,33 +417,14 @@ impl PostgresBackend {
|
||||
// remove null terminator
|
||||
let query_string = cstr_to_str(&body)?;
|
||||
|
||||
trace!("got query {:?}", query_string);
|
||||
// xxx distinguish fatal and recoverable errors?
|
||||
trace!("got query {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string) {
|
||||
// ":?" uses the alternate formatting style, which makes anyhow display the
|
||||
// full cause of the error, not just the top-level context + its trace.
|
||||
// We don't want to send that in the ErrorResponse though,
|
||||
// because it's not relevant to the compute node logs.
|
||||
//
|
||||
// We also don't want to log full stacktrace when the error is primitive,
|
||||
// such as usual connection closed.
|
||||
let short_error = format!("{:#}", e);
|
||||
let root_cause = e.root_cause().to_string();
|
||||
if root_cause.contains("connection closed unexpectedly")
|
||||
|| root_cause.contains("Broken pipe (os error 32)")
|
||||
{
|
||||
error!(
|
||||
"query handler for '{}' failed: {}",
|
||||
query_string, short_error
|
||||
);
|
||||
} else {
|
||||
error!("query handler for '{}' failed: {:?}", query_string, e);
|
||||
}
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(&short_error))?;
|
||||
// TODO: untangle convoluted control flow
|
||||
if e.to_string().contains("failed to run") {
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
log_query_error(query_string, &e);
|
||||
let short_error = short_error(&e);
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&short_error,
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
@@ -445,11 +449,13 @@ impl PostgresBackend {
|
||||
|
||||
FeMessage::Execute(_) => {
|
||||
let query_string = cstr_to_str(unnamed_query_string)?;
|
||||
trace!("got execute {:?}", query_string);
|
||||
// xxx distinguish fatal and recoverable errors?
|
||||
trace!("got execute {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string) {
|
||||
error!("query handler for '{}' failed: {:?}", query_string, e);
|
||||
self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?;
|
||||
log_query_error(query_string, &e);
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
// NOTE there is no ReadyForQuery message. This handler is used
|
||||
// for basebackup and it uses CopyOut which doesn't require
|
||||
@@ -468,7 +474,9 @@ impl PostgresBackend {
|
||||
// We prefer explicit pattern matching to wildcards, because
|
||||
// this helps us spot the places where new variants are missing
|
||||
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
|
||||
bail!("unexpected message type: {:?}", msg);
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unexpected message type: {msg:?}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,20 +2,59 @@
|
||||
//! To use, create PostgresBackend and run() it, passing the Handler
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
|
||||
use crate::postgres_backend::AuthType;
|
||||
use anyhow::{bail, Context, Result};
|
||||
use anyhow::Context;
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket};
|
||||
use std::future::Future;
|
||||
use futures::stream::StreamExt;
|
||||
use futures::{pin_mut, Sink, SinkExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Poll;
|
||||
use tracing::{debug, error, trace};
|
||||
|
||||
use std::{fmt, io};
|
||||
use std::{future::Future, str::FromStr};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_util::codec::Framed;
|
||||
use tracing::{debug, error, info, trace};
|
||||
|
||||
use pq_proto::codec::{ConnectionError, PostgresCodec};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
|
||||
|
||||
/// An error, occurred during query processing:
|
||||
/// either during the connection ([`ConnectionError`]) or before/after it.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum QueryError {
|
||||
/// The connection was lost while processing the query.
|
||||
#[error(transparent)]
|
||||
Disconnected(#[from] ConnectionError),
|
||||
/// Some other error
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl From<io::Error> for QueryError {
|
||||
fn from(e: io::Error) -> Self {
|
||||
Self::Disconnected(ConnectionError::Io(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryError {
|
||||
pub fn pg_error_code(&self) -> &'static [u8; 5] {
|
||||
match self {
|
||||
Self::Disconnected(_) => b"08006", // connection failure
|
||||
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_expected_io_error(e: &io::Error) -> bool {
|
||||
use io::ErrorKind::*;
|
||||
matches!(
|
||||
e.kind(),
|
||||
ConnectionRefused | ConnectionAborted | ConnectionReset
|
||||
)
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Handler {
|
||||
@@ -23,20 +62,32 @@ pub trait Handler {
|
||||
/// postgres_backend will issue ReadyForQuery after calling this (this
|
||||
/// might be not what we want after CopyData streaming, but currently we don't
|
||||
/// care).
|
||||
async fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: &str) -> Result<()>;
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError>;
|
||||
|
||||
/// Called on startup packet receival, allows to process params.
|
||||
///
|
||||
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
|
||||
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
|
||||
/// to override whole init logic in implementations.
|
||||
fn startup(&mut self, _pgb: &mut PostgresBackend, _sm: &FeStartupPacket) -> Result<()> {
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_sm: &FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check auth jwt
|
||||
fn check_auth_jwt(&mut self, _pgb: &mut PostgresBackend, _jwt_response: &[u8]) -> Result<()> {
|
||||
bail!("JWT auth failed")
|
||||
fn check_auth_jwt(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_jwt_response: &[u8],
|
||||
) -> Result<(), QueryError> {
|
||||
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,6 +96,7 @@ pub trait Handler {
|
||||
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
|
||||
pub enum ProtoState {
|
||||
Initialization,
|
||||
// Encryption handshake is done; waiting for encrypted Startup message.
|
||||
Encrypted,
|
||||
Authentication,
|
||||
Established,
|
||||
@@ -57,68 +109,95 @@ pub enum ProcessMsgResult {
|
||||
Break,
|
||||
}
|
||||
|
||||
/// Always-writeable sock_split stream.
|
||||
/// May not be readable. See [`PostgresBackend::take_stream_in`]
|
||||
pub enum Stream {
|
||||
Unencrypted(BufReader<tokio::net::TcpStream>),
|
||||
Tls(Box<tokio_rustls::server::TlsStream<BufReader<tokio::net::TcpStream>>>),
|
||||
Broken,
|
||||
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
|
||||
pub enum MaybeTlsStream {
|
||||
Unencrypted(tokio::net::TcpStream),
|
||||
Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
|
||||
Broken, // temporary value for switch to TLS
|
||||
}
|
||||
|
||||
impl AsyncWrite for Stream {
|
||||
impl AsyncWrite for MaybeTlsStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
) -> Poll<io::Result<usize>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
Self::Broken => unreachable!(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
|
||||
Self::Broken => unreachable!(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
Self::Broken => unreachable!(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AsyncRead for Stream {
|
||||
impl AsyncRead for MaybeTlsStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
Self::Broken => unreachable!(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
stream: Stream,
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum AuthType {
|
||||
Trust,
|
||||
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
|
||||
NeonJWT,
|
||||
}
|
||||
|
||||
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
|
||||
// The data between 0 and "current position" as tracked by the bytes::Buf
|
||||
// implementation of BytesMut, have already been written.
|
||||
buf_out: BytesMut,
|
||||
impl FromStr for AuthType {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"Trust" => Ok(Self::Trust),
|
||||
"NeonJWT" => Ok(Self::NeonJWT),
|
||||
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
AuthType::Trust => "Trust",
|
||||
AuthType::NeonJWT => "NeonJWT",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
// Provides serialization/deserialization to the underlying transport backed
|
||||
// with buffers; implements Sink consuming messages and Stream reading them.
|
||||
//
|
||||
// Sink::start_send only queues message to the interal buffer.
|
||||
// SinkExt::flush flushes buffer to the stream.
|
||||
//
|
||||
// StreamExt::read reads next message. In case of EOF without partial
|
||||
// message it returns None.
|
||||
stream: Framed<MaybeTlsStream, PostgresCodec>,
|
||||
|
||||
pub state: ProtoState,
|
||||
|
||||
@@ -139,7 +218,7 @@ pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
|
||||
}
|
||||
|
||||
// Cast a byte slice to a string slice, dropping null terminator if there's one.
|
||||
fn cstr_to_str(bytes: &[u8]) -> Result<&str> {
|
||||
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
|
||||
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
|
||||
std::str::from_utf8(without_null).map_err(|e| e.into())
|
||||
}
|
||||
@@ -149,12 +228,12 @@ impl PostgresBackend {
|
||||
socket: tokio::net::TcpStream,
|
||||
auth_type: AuthType,
|
||||
tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
) -> std::io::Result<Self> {
|
||||
) -> io::Result<Self> {
|
||||
let peer_addr = socket.peer_addr()?;
|
||||
let stream = MaybeTlsStream::Unencrypted(socket);
|
||||
|
||||
Ok(Self {
|
||||
stream: Stream::Unencrypted(BufReader::new(socket)),
|
||||
buf_out: BytesMut::with_capacity(10 * 1024),
|
||||
stream: Framed::new(stream, PostgresCodec::new()),
|
||||
state: ProtoState::Initialization,
|
||||
auth_type,
|
||||
tls_config,
|
||||
@@ -167,28 +246,60 @@ impl PostgresBackend {
|
||||
}
|
||||
|
||||
/// Read full message or return None if connection is closed.
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>> {
|
||||
use ProtoState::*;
|
||||
match self.state {
|
||||
Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await,
|
||||
Authentication | Established => FeMessage::read_fut(&mut self.stream).await,
|
||||
Closed => Ok(None),
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
if let ProtoState::Closed = self.state {
|
||||
Ok(None)
|
||||
} else {
|
||||
let msg = self.stream.next().await;
|
||||
// Option<Result<...>>, so swap.
|
||||
msg.map_or(Ok(None), |res| res.map(Some))
|
||||
}
|
||||
}
|
||||
|
||||
/// Polling version of read_message, saves the caller need to pin.
|
||||
pub fn poll_read_message(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<Option<FeMessage>, ConnectionError>> {
|
||||
let read_fut = self.read_message();
|
||||
pin_mut!(read_fut);
|
||||
read_fut.poll(cx)
|
||||
}
|
||||
|
||||
/// Flush output buffer into the socket.
|
||||
pub async fn flush(&mut self) -> std::io::Result<()> {
|
||||
while self.buf_out.has_remaining() {
|
||||
let bytes_written = self.stream.write(self.buf_out.chunk()).await?;
|
||||
self.buf_out.advance(bytes_written);
|
||||
}
|
||||
self.buf_out.clear();
|
||||
Ok(())
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
self.stream.flush().await.map_err(|e| match e {
|
||||
ConnectionError::Io(e) => e,
|
||||
// the only error we can get from flushing is IO
|
||||
_ => unreachable!(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer.
|
||||
pub fn write_message(&mut self, message: &BeMessage<'_>) -> Result<&mut Self, std::io::Error> {
|
||||
BeMessage::write(&mut self.buf_out, message)?;
|
||||
/// Polling version of `flush()`, saves the caller need to pin.
|
||||
pub fn poll_flush(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let flush_fut = self.flush();
|
||||
pin_mut!(flush_fut);
|
||||
flush_fut.poll(cx)
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer. Technically error type can be
|
||||
/// only ProtocolError here (if, unlikely, serialization fails), but callers
|
||||
/// typically wrap it anyway.
|
||||
pub fn write_message(&mut self, message: &BeMessage<'_>) -> Result<&mut Self, ConnectionError> {
|
||||
Pin::new(&mut self.stream).start_send(message)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer and flush it to the stream.
|
||||
pub async fn write_message_flush(
|
||||
&mut self,
|
||||
message: &BeMessage<'_>,
|
||||
) -> Result<&mut Self, ConnectionError> {
|
||||
self.write_message(message)?;
|
||||
self.flush().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
@@ -200,36 +311,18 @@ impl PostgresBackend {
|
||||
CopyDataWriter { pgb: self }
|
||||
}
|
||||
|
||||
/// A polling function that tries to write all the data from 'buf_out' to the
|
||||
/// underlying stream.
|
||||
fn poll_write_buf(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
while self.buf_out.has_remaining() {
|
||||
match Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk()) {
|
||||
Poll::Ready(Ok(bytes_written)) => {
|
||||
self.buf_out.advance(bytes_written);
|
||||
}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
Pin::new(&mut self.stream).poll_flush(cx)
|
||||
}
|
||||
|
||||
// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
pub async fn run<F, S>(mut self, handler: &mut impl Handler, shutdown_watcher: F) -> Result<()>
|
||||
pub async fn run<F, S>(
|
||||
mut self,
|
||||
handler: &mut impl Handler,
|
||||
shutdown_watcher: F,
|
||||
) -> Result<(), QueryError>
|
||||
where
|
||||
F: Fn() -> S,
|
||||
S: Future,
|
||||
{
|
||||
let ret = self.run_message_loop(handler, shutdown_watcher).await;
|
||||
let _ = self.stream.shutdown();
|
||||
let _ = self.stream.get_mut().shutdown();
|
||||
ret
|
||||
}
|
||||
|
||||
@@ -237,7 +330,7 @@ impl PostgresBackend {
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
shutdown_watcher: F,
|
||||
) -> Result<()>
|
||||
) -> Result<(), QueryError>
|
||||
where
|
||||
F: Fn() -> S,
|
||||
S: Future,
|
||||
@@ -273,7 +366,7 @@ impl PostgresBackend {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Ok::<(), anyhow::Error>(())
|
||||
Ok::<(), QueryError>(())
|
||||
} => {
|
||||
// Handshake complete.
|
||||
result?;
|
||||
@@ -309,34 +402,41 @@ impl PostgresBackend {
|
||||
}
|
||||
|
||||
async fn start_tls(&mut self) -> anyhow::Result<()> {
|
||||
if let Stream::Unencrypted(plain_stream) =
|
||||
std::mem::replace(&mut self.stream, Stream::Broken)
|
||||
if let MaybeTlsStream::Unencrypted(plain_stream) =
|
||||
// temporary replace stream with fake broken to prepare TLS one
|
||||
std::mem::replace(self.stream.get_mut(), MaybeTlsStream::Broken)
|
||||
{
|
||||
let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap());
|
||||
let tls_stream = acceptor.accept(plain_stream).await?;
|
||||
|
||||
self.stream = Stream::Tls(Box::new(tls_stream));
|
||||
return Ok(());
|
||||
match acceptor.accept(plain_stream).await {
|
||||
Ok(tls_stream) => {
|
||||
// push back ready TLS stream
|
||||
*self.stream.get_mut() = MaybeTlsStream::Tls(Box::new(tls_stream));
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
self.state = ProtoState::Closed;
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
};
|
||||
bail!("TLS already started");
|
||||
anyhow::bail!("TLS already started");
|
||||
}
|
||||
|
||||
async fn process_handshake_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
msg: FeMessage,
|
||||
) -> Result<ProcessMsgResult> {
|
||||
) -> Result<ProcessMsgResult, QueryError> {
|
||||
assert!(self.state < ProtoState::Established);
|
||||
let have_tls = self.tls_config.is_some();
|
||||
match msg {
|
||||
FeMessage::StartupPacket(m) => {
|
||||
trace!("got startup message {m:?}");
|
||||
|
||||
match m {
|
||||
FeStartupPacket::SslRequest => {
|
||||
debug!("SSL requested");
|
||||
|
||||
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
|
||||
|
||||
if have_tls {
|
||||
self.start_tls().await?;
|
||||
self.state = ProtoState::Encrypted;
|
||||
@@ -348,8 +448,13 @@ impl PostgresBackend {
|
||||
}
|
||||
FeStartupPacket::StartupMessage { .. } => {
|
||||
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
|
||||
self.write_message(&BeMessage::ErrorResponse("must connect with TLS"))?;
|
||||
bail!("client did not connect with TLS");
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
"must connect with TLS",
|
||||
None,
|
||||
))?;
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"client did not connect with TLS"
|
||||
)));
|
||||
}
|
||||
|
||||
// NB: startup() may change self.auth_type -- we are using that in proxy code
|
||||
@@ -360,6 +465,7 @@ impl PostgresBackend {
|
||||
AuthType::Trust => {
|
||||
self.write_message(&BeMessage::AuthenticationOk)?
|
||||
.write_message(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::INTEGER_DATETIMES)?
|
||||
// The async python driver requires a valid server_version
|
||||
.write_message(&BeMessage::server_version("14.1"))?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
@@ -389,13 +495,17 @@ impl PostgresBackend {
|
||||
let (_, jwt_response) = m.split_last().context("protocol violation")?;
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?;
|
||||
bail!("auth failed: {}", e);
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
self.write_message(&BeMessage::AuthenticationOk)?
|
||||
.write_message(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::INTEGER_DATETIMES)?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
@@ -413,33 +523,28 @@ impl PostgresBackend {
|
||||
handler: &mut impl Handler,
|
||||
msg: FeMessage,
|
||||
unnamed_query_string: &mut Bytes,
|
||||
) -> Result<ProcessMsgResult> {
|
||||
) -> Result<ProcessMsgResult, QueryError> {
|
||||
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
|
||||
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
|
||||
assert!(self.state == ProtoState::Established);
|
||||
|
||||
match msg {
|
||||
FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => {
|
||||
bail!("protocol violation");
|
||||
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
|
||||
}
|
||||
|
||||
FeMessage::Query(body) => {
|
||||
// remove null terminator
|
||||
let query_string = cstr_to_str(&body)?;
|
||||
|
||||
trace!("got query {:?}", query_string);
|
||||
// xxx distinguish fatal and recoverable errors?
|
||||
trace!("got query {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string).await {
|
||||
// ":?" uses the alternate formatting style, which makes anyhow display the
|
||||
// full cause of the error, not just the top-level context + its trace.
|
||||
// We don't want to send that in the ErrorResponse though,
|
||||
// because it's not relevant to the compute node logs.
|
||||
error!("query handler for '{}' failed: {:?}", query_string, e);
|
||||
self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?;
|
||||
// TODO: untangle convoluted control flow
|
||||
if e.to_string().contains("failed to run") {
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
log_query_error(query_string, &e);
|
||||
let short_error = short_error(&e);
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
&short_error,
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
@@ -464,11 +569,13 @@ impl PostgresBackend {
|
||||
|
||||
FeMessage::Execute(_) => {
|
||||
let query_string = cstr_to_str(unnamed_query_string)?;
|
||||
trace!("got execute {:?}", query_string);
|
||||
// xxx distinguish fatal and recoverable errors?
|
||||
trace!("got execute {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string).await {
|
||||
error!("query handler for '{}' failed: {:?}", query_string, e);
|
||||
self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?;
|
||||
log_query_error(query_string, &e);
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
// NOTE there is no ReadyForQuery message. This handler is used
|
||||
// for basebackup and it uses CopyOut which doesn't require
|
||||
@@ -487,7 +594,10 @@ impl PostgresBackend {
|
||||
// We prefer explicit pattern matching to wildcards, because
|
||||
// this helps us spot the places where new variants are missing
|
||||
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
|
||||
bail!("unexpected message type: {:?}", msg);
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unexpected message type: {:?}",
|
||||
msg
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -515,7 +625,7 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
|
||||
// It's not strictly required to flush between each message, but makes it easier
|
||||
// to view in wireshark, and usually the messages that the callers write are
|
||||
// decently-sized anyway.
|
||||
match this.pgb.poll_write_buf(cx) {
|
||||
match this.pgb.poll_flush(cx) {
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
@@ -525,7 +635,11 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
|
||||
// XXX: if the input is large, we should split it into multiple messages.
|
||||
// Not sure what the threshold should be, but the ultimate hard limit is that
|
||||
// the length cannot exceed u32.
|
||||
this.pgb.write_message(&BeMessage::CopyData(buf))?;
|
||||
this.pgb
|
||||
.write_message(&BeMessage::CopyData(buf))
|
||||
// write_message only writes to buffer, so can fail iff message is
|
||||
// invaid, but CopyData can't be invalid.
|
||||
.expect("failed to serialize CopyData");
|
||||
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
@@ -535,23 +649,39 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
match this.pgb.poll_write_buf(cx) {
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
this.pgb.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
match this.pgb.poll_write_buf(cx) {
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
this.pgb.poll_flush(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn short_error(e: &QueryError) -> String {
|
||||
match e {
|
||||
QueryError::Disconnected(connection_error) => connection_error.to_string(),
|
||||
QueryError::Other(e) => format!("{e:#}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn log_query_error(query: &str, e: &QueryError) {
|
||||
match e {
|
||||
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
|
||||
if is_expected_io_error(io_error) {
|
||||
info!("query handler for '{query}' failed with expected io error: {io_error}");
|
||||
} else {
|
||||
error!("query handler for '{query}' failed with io error: {io_error}");
|
||||
}
|
||||
}
|
||||
QueryError::Disconnected(other_connection_error) => {
|
||||
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
|
||||
}
|
||||
QueryError::Other(e) => {
|
||||
error!("query handler for '{query}' failed: {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
116
libs/utils/src/send_rc.rs
Normal file
116
libs/utils/src/send_rc.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
/// Provides Send wrappers of Rc and RefMut.
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
cell::{Ref, RefCell, RefMut},
|
||||
ops::{Deref, DerefMut},
|
||||
rc::Rc,
|
||||
};
|
||||
|
||||
/// Rc wrapper which is Send.
|
||||
/// This is useful to allow transferring a group of Rcs pointing to the same
|
||||
/// object between threads, e.g. in self referential struct.
|
||||
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
|
||||
pub struct SendRc<T>
|
||||
where
|
||||
T: ?Sized,
|
||||
{
|
||||
rc: Rc<T>,
|
||||
}
|
||||
|
||||
// SAFETY: Passing Rc(s)<T: Send> between threads is fine as long as there is no
|
||||
// concurrent access to the object they point to, so you must move all such Rcs
|
||||
// together. This appears to be impossible to express in rust type system and
|
||||
// SendRc doesn't provide any additional protection -- but unlike sendable
|
||||
// crate, neither it requires any additional actions before/after move. Ensuring
|
||||
// that sending conforms to the above is the responsibility of the type user.
|
||||
unsafe impl<T: ?Sized + Send> Send for SendRc<T> {}
|
||||
|
||||
impl<T> SendRc<T> {
|
||||
/// Constructs a new SendRc<T>
|
||||
pub fn new(value: T) -> SendRc<T> {
|
||||
SendRc { rc: Rc::new(value) }
|
||||
}
|
||||
}
|
||||
|
||||
// https://stegosaurusdormant.com/understanding-derive-clone/ explains in detail
|
||||
// why derive Clone doesn't work here.
|
||||
impl<T> Clone for SendRc<T> {
|
||||
fn clone(&self) -> Self {
|
||||
SendRc {
|
||||
rc: self.rc.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deref into inner rc.
|
||||
impl<T> Deref for SendRc<T> {
|
||||
type Target = Rc<T>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.rc
|
||||
}
|
||||
}
|
||||
|
||||
/// Extends RefCell with borrow[_mut] variants which return Sendable Ref[Mut]
|
||||
/// wrappers.
|
||||
pub trait RefCellSend<T: ?Sized> {
|
||||
fn borrow_mut_send(&self) -> RefMutSend<'_, T>;
|
||||
}
|
||||
|
||||
impl<T: Sized> RefCellSend<T> for RefCell<T> {
|
||||
fn borrow_mut_send(&self) -> RefMutSend<'_, T> {
|
||||
RefMutSend {
|
||||
ref_mut: self.borrow_mut(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// RefMut wrapper which is Send. See impl Send for safety. Allows to move a
|
||||
/// RefMut along with RefCell it originates from between threads, e.g. have Send
|
||||
/// Future containing RefMut.
|
||||
#[derive(Debug)]
|
||||
pub struct RefMutSend<'b, T>
|
||||
where
|
||||
T: 'b + ?Sized,
|
||||
{
|
||||
ref_mut: RefMut<'b, T>,
|
||||
}
|
||||
|
||||
// SAFETY: Similar to SendRc, this is safe as long as RefMut stays in the same
|
||||
// thread with original RefCell, so they should be passed together.
|
||||
// Actually, since this is a referential type violating this is not
|
||||
// straightforward; examples of unsafe usage could be
|
||||
// - Passing a RefMut to different thread without source RefCell. Seems only
|
||||
// possible with std::thread::scope.
|
||||
// - Somehow multiple threads get access to single RefCell concurrently,
|
||||
// violating its !Sync requirement. Improper usage of SendRc can do that.
|
||||
unsafe impl<'b, T: ?Sized + Send> Send for RefMutSend<'b, T> {}
|
||||
|
||||
impl<'b, T> RefMutSend<'b, T> {
|
||||
/// Constructs a new RefMutSend<T>
|
||||
pub fn new(ref_mut: RefMut<'b, T>) -> RefMutSend<'b, T> {
|
||||
RefMutSend { ref_mut }
|
||||
}
|
||||
}
|
||||
|
||||
// Deref into inner RefMut.
|
||||
impl<'b, T> Deref for RefMutSend<'b, T>
|
||||
where
|
||||
T: 'b + ?Sized,
|
||||
{
|
||||
type Target = RefMut<'b, T>;
|
||||
|
||||
fn deref<'a>(&'a self) -> &'a RefMut<'b, T> {
|
||||
&self.ref_mut
|
||||
}
|
||||
}
|
||||
|
||||
// DerefMut into inner RefMut.
|
||||
impl<'b, T> DerefMut for RefMutSend<'b, T>
|
||||
where
|
||||
T: 'b + ?Sized,
|
||||
{
|
||||
fn deref_mut<'a>(&'a mut self) -> &'a mut RefMut<'b, T> {
|
||||
&mut self.ref_mut
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,10 @@ use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
use utils::postgres_backend::{AuthType, Handler, PostgresBackend};
|
||||
use utils::{
|
||||
postgres_backend::{AuthType, Handler, PostgresBackend},
|
||||
postgres_backend_async::QueryError,
|
||||
};
|
||||
|
||||
fn make_tcp_pair() -> (TcpStream, TcpStream) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
@@ -105,7 +108,7 @@ fn ssl() {
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> Result<(), QueryError> {
|
||||
self.got_query = query_string == QUERY;
|
||||
Ok(())
|
||||
}
|
||||
@@ -152,7 +155,7 @@ fn no_ssl() {
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_query_string: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> Result<(), QueryError> {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
@@ -212,7 +215,7 @@ fn server_forces_ssl() {
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_query_string: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> Result<(), QueryError> {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,14 @@ use postgres_ffi::PG_TLI;
|
||||
use postgres_ffi::{BLCKSZ, RELSEG_SIZE, WAL_SEGMENT_SIZE};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
/// Create basebackup with non-rel data in it.
|
||||
/// Only include relational data if 'full_backup' is true.
|
||||
///
|
||||
/// Currently we use empty 'req_lsn' in two cases:
|
||||
/// * During the basebackup right after timeline creation
|
||||
/// * When working without safekeepers. In this situation it is important to match the lsn
|
||||
/// we are taking basebackup on with the lsn that is used in pageserver's walreceiver
|
||||
/// to start the replication.
|
||||
pub async fn send_basebackup_tarball<'a, W>(
|
||||
write: &'a mut W,
|
||||
timeline: &'a Timeline,
|
||||
@@ -123,14 +131,6 @@ where
|
||||
full_backup: bool,
|
||||
}
|
||||
|
||||
// Create basebackup with non-rel data in it.
|
||||
// Only include relational data if 'full_backup' is true.
|
||||
//
|
||||
// Currently we use empty lsn in two cases:
|
||||
// * During the basebackup right after timeline creation
|
||||
// * When working without safekeepers. In this situation it is important to match the lsn
|
||||
// we are taking basebackup on with the lsn that is used in pageserver's walreceiver
|
||||
// to start the replication.
|
||||
impl<'a, W> Basebackup<'a, W>
|
||||
where
|
||||
W: AsyncWrite + Send + Sync + Unpin,
|
||||
|
||||
@@ -24,7 +24,7 @@ use pageserver::{
|
||||
use utils::{
|
||||
auth::JwtAuth,
|
||||
logging,
|
||||
postgres_backend::AuthType,
|
||||
postgres_backend_async::AuthType,
|
||||
project_git_version,
|
||||
sentry_init::{init_sentry, release_name},
|
||||
signals::{self, Signal},
|
||||
|
||||
@@ -24,7 +24,7 @@ use toml_edit::{Document, Item};
|
||||
use utils::{
|
||||
id::{NodeId, TenantId, TimelineId},
|
||||
logging::LogFormat,
|
||||
postgres_backend::AuthType,
|
||||
postgres_backend_async::AuthType,
|
||||
};
|
||||
|
||||
use crate::tenant::config::TenantConf;
|
||||
|
||||
@@ -738,17 +738,17 @@ async fn timeline_compact_handler(request: Request<Body>) -> Result<Response<Bod
|
||||
let timeline_id: TimelineId = parse_request_param(&request, "timeline_id")?;
|
||||
check_permission(&request, Some(tenant_id))?;
|
||||
|
||||
let tenant = mgr::get_tenant(tenant_id, true)
|
||||
.await
|
||||
.map_err(ApiError::NotFound)?;
|
||||
let timeline = tenant
|
||||
.get_timeline(timeline_id, true)
|
||||
.map_err(ApiError::NotFound)?;
|
||||
timeline
|
||||
.compact()
|
||||
let result_receiver = mgr::immediate_compact(tenant_id, timeline_id)
|
||||
.await
|
||||
.context("spawn compaction task")
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
|
||||
let result: anyhow::Result<()> = result_receiver
|
||||
.await
|
||||
.context("receive compaction result")
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
result.map_err(ApiError::InternalServerError)?;
|
||||
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
|
||||
|
||||
@@ -209,15 +209,34 @@ pub static NUM_ONDISK_LAYERS: Lazy<IntGauge> = Lazy::new(|| {
|
||||
|
||||
// remote storage metrics
|
||||
|
||||
static REMOTE_UPLOAD_QUEUE_UNFINISHED_TASKS: Lazy<IntGaugeVec> = Lazy::new(|| {
|
||||
/// NB: increment _after_ recording the current value into [`REMOTE_TIMELINE_CLIENT_CALLS_STARTED_HIST`].
|
||||
static REMOTE_TIMELINE_CLIENT_CALLS_UNFINISHED_GAUGE: Lazy<IntGaugeVec> = Lazy::new(|| {
|
||||
register_int_gauge_vec!(
|
||||
"pageserver_remote_upload_queue_unfinished_tasks",
|
||||
"Number of tasks in the upload queue that are not finished yet.",
|
||||
"pageserver_remote_timeline_client_calls_unfinished",
|
||||
"Number of ongoing calls to remote timeline client. \
|
||||
Used to populate pageserver_remote_timeline_client_calls_started. \
|
||||
This metric is not useful for sampling from Prometheus, but useful in tests.",
|
||||
&["tenant_id", "timeline_id", "file_kind", "op_kind"],
|
||||
)
|
||||
.expect("failed to define a metric")
|
||||
});
|
||||
|
||||
static REMOTE_TIMELINE_CLIENT_CALLS_STARTED_HIST: Lazy<HistogramVec> = Lazy::new(|| {
|
||||
register_histogram_vec!(
|
||||
"pageserver_remote_timeline_client_calls_started",
|
||||
"When calling a remote timeline client method, we record the current value \
|
||||
of the calls_unfinished gauge in this histogram. Plot the histogram \
|
||||
over time in a heatmap to visualize how many operations were ongoing \
|
||||
at a given instant. It gives you a better idea of the queue depth \
|
||||
than plotting the gauge directly, since operations may complete faster \
|
||||
than the sampling interval.",
|
||||
&["tenant_id", "timeline_id", "file_kind", "op_kind"],
|
||||
// The calls_unfinished gauge is an integer gauge, hence we have integer buckets.
|
||||
vec![0.0, 1.0, 2.0, 4.0, 6.0, 8.0, 10.0, 15.0, 20.0, 40.0, 60.0, 80.0, 100.0, 500.0],
|
||||
)
|
||||
.expect("failed to define a metric")
|
||||
});
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum RemoteOpKind {
|
||||
Upload,
|
||||
@@ -248,15 +267,12 @@ impl RemoteOpFileKind {
|
||||
}
|
||||
}
|
||||
|
||||
pub static REMOTE_OPERATION_KINDS: &[&str] = &["upload", "download", "delete"];
|
||||
pub static REMOTE_OPERATION_FILE_KINDS: &[&str] = &["layer", "index"];
|
||||
pub static REMOTE_OPERATION_STATUSES: &[&str] = &["success", "failure"];
|
||||
|
||||
pub static REMOTE_OPERATION_TIME: Lazy<HistogramVec> = Lazy::new(|| {
|
||||
register_histogram_vec!(
|
||||
"pageserver_remote_operation_seconds",
|
||||
"Time spent on remote storage operations. \
|
||||
Grouped by tenant, timeline, operation_kind and status",
|
||||
Grouped by tenant, timeline, operation_kind and status. \
|
||||
Does not account for time spent waiting in remote timeline client's queues.",
|
||||
&["tenant_id", "timeline_id", "file_kind", "op_kind", "status"]
|
||||
)
|
||||
.expect("failed to define a metric")
|
||||
@@ -475,21 +491,6 @@ impl Drop for TimelineMetrics {
|
||||
for op in SMGR_QUERY_TIME_OPERATIONS {
|
||||
let _ = SMGR_QUERY_TIME.remove_label_values(&[op, tenant_id, timeline_id]);
|
||||
}
|
||||
|
||||
let _ = REMOTE_UPLOAD_QUEUE_UNFINISHED_TASKS.remove_label_values(&[tenant_id, timeline_id]);
|
||||
for file_kind in REMOTE_OPERATION_FILE_KINDS {
|
||||
for op in REMOTE_OPERATION_KINDS {
|
||||
for status in REMOTE_OPERATION_STATUSES {
|
||||
let _ = REMOTE_OPERATION_TIME.remove_label_values(&[
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
file_kind,
|
||||
op,
|
||||
status,
|
||||
]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,7 +511,8 @@ pub struct RemoteTimelineClientMetrics {
|
||||
timeline_id: String,
|
||||
remote_physical_size_gauge: Mutex<Option<UIntGauge>>,
|
||||
remote_operation_time: Mutex<HashMap<(&'static str, &'static str, &'static str), Histogram>>,
|
||||
unfinished_tasks: Mutex<HashMap<(&'static str, &'static str), IntGauge>>,
|
||||
calls_unfinished_gauge: Mutex<HashMap<(&'static str, &'static str), IntGauge>>,
|
||||
calls_started_hist: Mutex<HashMap<(&'static str, &'static str), Histogram>>,
|
||||
}
|
||||
|
||||
impl RemoteTimelineClientMetrics {
|
||||
@@ -519,7 +521,8 @@ impl RemoteTimelineClientMetrics {
|
||||
tenant_id: tenant_id.to_string(),
|
||||
timeline_id: timeline_id.to_string(),
|
||||
remote_operation_time: Mutex::new(HashMap::default()),
|
||||
unfinished_tasks: Mutex::new(HashMap::default()),
|
||||
calls_unfinished_gauge: Mutex::new(HashMap::default()),
|
||||
calls_started_hist: Mutex::new(HashMap::default()),
|
||||
remote_physical_size_gauge: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
@@ -558,16 +561,37 @@ impl RemoteTimelineClientMetrics {
|
||||
});
|
||||
metric.clone()
|
||||
}
|
||||
pub fn unfinished_tasks(
|
||||
fn calls_unfinished_gauge(
|
||||
&self,
|
||||
file_kind: &RemoteOpFileKind,
|
||||
op_kind: &RemoteOpKind,
|
||||
) -> IntGauge {
|
||||
// XXX would be nice to have an upgradable RwLock
|
||||
let mut guard = self.unfinished_tasks.lock().unwrap();
|
||||
let mut guard = self.calls_unfinished_gauge.lock().unwrap();
|
||||
let key = (file_kind.as_str(), op_kind.as_str());
|
||||
let metric = guard.entry(key).or_insert_with(move || {
|
||||
REMOTE_UPLOAD_QUEUE_UNFINISHED_TASKS
|
||||
REMOTE_TIMELINE_CLIENT_CALLS_UNFINISHED_GAUGE
|
||||
.get_metric_with_label_values(&[
|
||||
&self.tenant_id.to_string(),
|
||||
&self.timeline_id.to_string(),
|
||||
key.0,
|
||||
key.1,
|
||||
])
|
||||
.unwrap()
|
||||
});
|
||||
metric.clone()
|
||||
}
|
||||
|
||||
fn calls_started_hist(
|
||||
&self,
|
||||
file_kind: &RemoteOpFileKind,
|
||||
op_kind: &RemoteOpKind,
|
||||
) -> Histogram {
|
||||
// XXX would be nice to have an upgradable RwLock
|
||||
let mut guard = self.calls_started_hist.lock().unwrap();
|
||||
let key = (file_kind.as_str(), op_kind.as_str());
|
||||
let metric = guard.entry(key).or_insert_with(move || {
|
||||
REMOTE_TIMELINE_CLIENT_CALLS_STARTED_HIST
|
||||
.get_metric_with_label_values(&[
|
||||
&self.tenant_id.to_string(),
|
||||
&self.timeline_id.to_string(),
|
||||
@@ -580,6 +604,58 @@ impl RemoteTimelineClientMetrics {
|
||||
}
|
||||
}
|
||||
|
||||
/// See [`RemoteTimelineClientMetrics::call_begin`].
|
||||
#[must_use]
|
||||
pub(crate) struct RemoteTimelineClientCallMetricGuard(Option<IntGauge>);
|
||||
|
||||
impl RemoteTimelineClientCallMetricGuard {
|
||||
/// Consume this guard object without decrementing the metric.
|
||||
/// The caller vouches to do this manually, so that the prior increment of the gauge will cancel out.
|
||||
pub fn will_decrement_manually(mut self) {
|
||||
self.0 = None; // prevent drop() from decrementing
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for RemoteTimelineClientCallMetricGuard {
|
||||
fn drop(&mut self) {
|
||||
if let RemoteTimelineClientCallMetricGuard(Some(guard)) = self {
|
||||
guard.dec();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RemoteTimelineClientMetrics {
|
||||
/// Increment the metrics that track ongoing calls to the remote timeline client instance.
|
||||
///
|
||||
/// Drop the returned guard object once the operation is finished to decrement the values.
|
||||
/// Or, use [`RemoteTimelineClientCallMetricGuard::will_decrement_manually`] and [`call_end`] if that
|
||||
/// is more suitable.
|
||||
/// Never do both.
|
||||
pub(crate) fn call_begin(
|
||||
&self,
|
||||
file_kind: &RemoteOpFileKind,
|
||||
op_kind: &RemoteOpKind,
|
||||
) -> RemoteTimelineClientCallMetricGuard {
|
||||
let unfinished_metric = self.calls_unfinished_gauge(file_kind, op_kind);
|
||||
self.calls_started_hist(file_kind, op_kind)
|
||||
.observe(unfinished_metric.get() as f64);
|
||||
unfinished_metric.inc();
|
||||
RemoteTimelineClientCallMetricGuard(Some(unfinished_metric))
|
||||
}
|
||||
|
||||
/// Manually decrement the metric instead of using the guard object.
|
||||
/// Using the guard object is generally preferable.
|
||||
/// See [`call_begin`] for more context.
|
||||
pub(crate) fn call_end(&self, file_kind: &RemoteOpFileKind, op_kind: &RemoteOpKind) {
|
||||
let unfinished_metric = self.calls_unfinished_gauge(file_kind, op_kind);
|
||||
debug_assert!(
|
||||
unfinished_metric.get() > 0,
|
||||
"begin and end should cancel out"
|
||||
);
|
||||
unfinished_metric.dec();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for RemoteTimelineClientMetrics {
|
||||
fn drop(&mut self) {
|
||||
let RemoteTimelineClientMetrics {
|
||||
@@ -587,13 +663,22 @@ impl Drop for RemoteTimelineClientMetrics {
|
||||
timeline_id,
|
||||
remote_physical_size_gauge,
|
||||
remote_operation_time,
|
||||
unfinished_tasks,
|
||||
calls_unfinished_gauge,
|
||||
calls_started_hist,
|
||||
} = self;
|
||||
for ((a, b, c), _) in remote_operation_time.get_mut().unwrap().drain() {
|
||||
let _ = REMOTE_OPERATION_TIME.remove_label_values(&[tenant_id, timeline_id, a, b, c]);
|
||||
}
|
||||
for ((a, b), _) in unfinished_tasks.get_mut().unwrap().drain() {
|
||||
let _ = REMOTE_UPLOAD_QUEUE_UNFINISHED_TASKS.remove_label_values(&[
|
||||
for ((a, b), _) in calls_unfinished_gauge.get_mut().unwrap().drain() {
|
||||
let _ = REMOTE_TIMELINE_CLIENT_CALLS_UNFINISHED_GAUGE.remove_label_values(&[
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
a,
|
||||
b,
|
||||
]);
|
||||
}
|
||||
for ((a, b), _) in calls_started_hist.get_mut().unwrap().drain() {
|
||||
let _ = REMOTE_TIMELINE_CLIENT_CALLS_STARTED_HIST.remove_label_values(&[
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
a,
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
// custom protocol.
|
||||
//
|
||||
|
||||
use anyhow::{bail, ensure, Context, Result};
|
||||
use anyhow::Context;
|
||||
use bytes::Buf;
|
||||
use bytes::Bytes;
|
||||
use futures::{Stream, StreamExt};
|
||||
@@ -19,6 +19,8 @@ use pageserver_api::models::{
|
||||
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
|
||||
PagestreamNblocksRequest, PagestreamNblocksResponse,
|
||||
};
|
||||
use pq_proto::codec::ConnectionError;
|
||||
use pq_proto::FeStartupPacket;
|
||||
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
|
||||
use std::io;
|
||||
use std::net::TcpListener;
|
||||
@@ -28,11 +30,12 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::*;
|
||||
use utils::id::ConnectionId;
|
||||
use utils::postgres_backend_async::QueryError;
|
||||
use utils::{
|
||||
auth::{Claims, JwtAuth, Scope},
|
||||
id::{TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::AuthType,
|
||||
postgres_backend_async::AuthType,
|
||||
postgres_backend_async::{self, PostgresBackend},
|
||||
simple_rcu::RcuReadGuard,
|
||||
};
|
||||
@@ -60,11 +63,11 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
||||
_ = task_mgr::shutdown_watcher() => {
|
||||
// We were requested to shut down.
|
||||
let msg = format!("pageserver is shutting down");
|
||||
let _ = pgb.write_message(&BeMessage::ErrorResponse(&msg));
|
||||
Err(anyhow::anyhow!(msg))
|
||||
let _ = pgb.write_message(&BeMessage::ErrorResponse(&msg, None));
|
||||
Err(QueryError::Other(anyhow::anyhow!(msg)))
|
||||
}
|
||||
|
||||
msg = pgb.read_message() => { msg }
|
||||
msg = pgb.read_message() => { msg.map_err(QueryError::from)}
|
||||
};
|
||||
|
||||
match msg {
|
||||
@@ -74,14 +77,17 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
||||
FeMessage::CopyDone => { break },
|
||||
FeMessage::Sync => continue,
|
||||
FeMessage::Terminate => {
|
||||
let msg = format!("client terminated connection with Terminate message during COPY");
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&msg))?;
|
||||
let msg = "client terminated connection with Terminate message during COPY";
|
||||
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code())))
|
||||
.expect("failed to serialize ErrorResponse");
|
||||
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
||||
break;
|
||||
}
|
||||
m => {
|
||||
let msg = format!("unexpected message {:?}", m);
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&msg))?;
|
||||
let msg = format!("unexpected message {m:?}");
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))
|
||||
.expect("failed to serialize ErrorResponse");
|
||||
Err(io::Error::new(io::ErrorKind::Other, msg))?;
|
||||
break;
|
||||
}
|
||||
@@ -91,12 +97,17 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
||||
}
|
||||
Ok(None) => {
|
||||
let msg = "client closed connection during COPY";
|
||||
pgb.write_message(&BeMessage::ErrorResponse(msg))?;
|
||||
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code())))
|
||||
.expect("failed to serialize ErrorResponse");
|
||||
pgb.flush().await?;
|
||||
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
||||
}
|
||||
Err(e) => {
|
||||
Err(io::Error::new(io::ErrorKind::Other, e))?;
|
||||
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
|
||||
Err(io_error)?;
|
||||
}
|
||||
Err(other) => {
|
||||
Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -194,23 +205,19 @@ async fn page_service_conn_main(
|
||||
// we've been requested to shut down
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
let root_cause_io_err_kind = err
|
||||
.root_cause()
|
||||
.downcast_ref::<io::Error>()
|
||||
.map(|e| e.kind());
|
||||
|
||||
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
|
||||
// `ConnectionReset` error happens when the Postgres client closes the connection.
|
||||
// As this disconnection happens quite often and is expected,
|
||||
// we decided to downgrade the logging level to `INFO`.
|
||||
// See: https://github.com/neondatabase/neon/issues/1683.
|
||||
if root_cause_io_err_kind == Some(io::ErrorKind::ConnectionReset) {
|
||||
if io_error.kind() == io::ErrorKind::ConnectionReset {
|
||||
info!("Postgres client disconnected");
|
||||
Ok(())
|
||||
} else {
|
||||
Err(err)
|
||||
Err(io_error).context("Postgres connection error")
|
||||
}
|
||||
}
|
||||
other => other.context("Postgres query error"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -312,7 +319,7 @@ impl PageServerHandler {
|
||||
Some(FeMessage::CopyData(bytes)) => bytes,
|
||||
Some(FeMessage::Terminate) => break,
|
||||
Some(m) => {
|
||||
bail!("unexpected message: {m:?} during COPY");
|
||||
anyhow::bail!("unexpected message: {m:?} during COPY");
|
||||
}
|
||||
None => break, // client disconnected
|
||||
};
|
||||
@@ -369,7 +376,7 @@ impl PageServerHandler {
|
||||
base_lsn: Lsn,
|
||||
_end_lsn: Lsn,
|
||||
pg_version: u32,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> Result<(), QueryError> {
|
||||
task_mgr::associate_with(Some(tenant_id), Some(timeline_id));
|
||||
// Create empty timeline
|
||||
info!("creating new timeline");
|
||||
@@ -423,11 +430,16 @@ impl PageServerHandler {
|
||||
timeline_id: TimelineId,
|
||||
start_lsn: Lsn,
|
||||
end_lsn: Lsn,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> Result<(), QueryError> {
|
||||
task_mgr::associate_with(Some(tenant_id), Some(timeline_id));
|
||||
|
||||
let timeline = get_active_timeline_with_timeout(tenant_id, timeline_id).await?;
|
||||
ensure!(timeline.get_last_record_lsn() == start_lsn);
|
||||
let last_record_lsn = timeline.get_last_record_lsn();
|
||||
if last_record_lsn != start_lsn {
|
||||
return Err(QueryError::Other(
|
||||
anyhow::anyhow!("Cannot import WAL from Lsn {start_lsn} because timeline does not start from the same lsn: {last_record_lsn}"))
|
||||
);
|
||||
}
|
||||
|
||||
// TODO leave clean state on error. For now you can use detach to clean
|
||||
// up broken state from a failed import.
|
||||
@@ -451,7 +463,11 @@ impl PageServerHandler {
|
||||
}
|
||||
|
||||
// TODO Does it make sense to overshoot?
|
||||
ensure!(timeline.get_last_record_lsn() >= end_lsn);
|
||||
if timeline.get_last_record_lsn() < end_lsn {
|
||||
return Err(QueryError::Other(
|
||||
anyhow::anyhow!("Cannot import WAL from Lsn {start_lsn} because timeline does not start from the same lsn: {last_record_lsn}"))
|
||||
);
|
||||
}
|
||||
|
||||
// Flush data to disk, then upload to s3. No need for a forced checkpoint.
|
||||
// We only want to persist the data, and it doesn't matter if it's in the
|
||||
@@ -480,7 +496,7 @@ impl PageServerHandler {
|
||||
mut lsn: Lsn,
|
||||
latest: bool,
|
||||
latest_gc_cutoff_lsn: &RcuReadGuard<Lsn>,
|
||||
) -> Result<Lsn> {
|
||||
) -> anyhow::Result<Lsn> {
|
||||
if latest {
|
||||
// Latest page version was requested. If LSN is given, it is a hint
|
||||
// to the page server that there have been no modifications to the
|
||||
@@ -511,11 +527,11 @@ impl PageServerHandler {
|
||||
}
|
||||
} else {
|
||||
if lsn == Lsn(0) {
|
||||
bail!("invalid LSN(0) in request");
|
||||
anyhow::bail!("invalid LSN(0) in request");
|
||||
}
|
||||
timeline.wait_lsn(lsn).await?;
|
||||
}
|
||||
ensure!(
|
||||
anyhow::ensure!(
|
||||
lsn >= **latest_gc_cutoff_lsn,
|
||||
"tried to request a page version that was garbage collected. requested at {} gc cutoff {}",
|
||||
lsn, **latest_gc_cutoff_lsn
|
||||
@@ -528,7 +544,7 @@ impl PageServerHandler {
|
||||
&self,
|
||||
timeline: &Timeline,
|
||||
req: &PagestreamExistsRequest,
|
||||
) -> Result<PagestreamBeMessage> {
|
||||
) -> anyhow::Result<PagestreamBeMessage> {
|
||||
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
|
||||
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn)
|
||||
.await?;
|
||||
@@ -548,7 +564,7 @@ impl PageServerHandler {
|
||||
&self,
|
||||
timeline: &Timeline,
|
||||
req: &PagestreamNblocksRequest,
|
||||
) -> Result<PagestreamBeMessage> {
|
||||
) -> anyhow::Result<PagestreamBeMessage> {
|
||||
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
|
||||
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn)
|
||||
.await?;
|
||||
@@ -568,7 +584,7 @@ impl PageServerHandler {
|
||||
&self,
|
||||
timeline: &Timeline,
|
||||
req: &PagestreamDbSizeRequest,
|
||||
) -> Result<PagestreamBeMessage> {
|
||||
) -> anyhow::Result<PagestreamBeMessage> {
|
||||
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
|
||||
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn)
|
||||
.await?;
|
||||
@@ -589,7 +605,7 @@ impl PageServerHandler {
|
||||
&self,
|
||||
timeline: &Timeline,
|
||||
req: &PagestreamGetPageRequest,
|
||||
) -> Result<PagestreamBeMessage> {
|
||||
) -> anyhow::Result<PagestreamBeMessage> {
|
||||
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
|
||||
let lsn = Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn)
|
||||
.await?;
|
||||
@@ -654,7 +670,7 @@ impl PageServerHandler {
|
||||
|
||||
// when accessing management api supply None as an argument
|
||||
// when using to authorize tenant pass corresponding tenant id
|
||||
fn check_permission(&self, tenant_id: Option<TenantId>) -> Result<()> {
|
||||
fn check_permission(&self, tenant_id: Option<TenantId>) -> anyhow::Result<()> {
|
||||
if self.auth.is_none() {
|
||||
// auth is set to Trust, nothing to check so just return ok
|
||||
return Ok(());
|
||||
@@ -676,20 +692,19 @@ impl postgres_backend_async::Handler for PageServerHandler {
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
jwt_response: &[u8],
|
||||
) -> anyhow::Result<()> {
|
||||
) -> Result<(), QueryError> {
|
||||
// this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
|
||||
// which requires auth to be present
|
||||
let data = self
|
||||
.auth
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.decode(str::from_utf8(jwt_response)?)?;
|
||||
.decode(str::from_utf8(jwt_response).context("jwt response is not UTF-8")?)?;
|
||||
|
||||
if matches!(data.claims.scope, Scope::Tenant) {
|
||||
ensure!(
|
||||
data.claims.tenant_id.is_some(),
|
||||
if matches!(data.claims.scope, Scope::Tenant) && data.claims.tenant_id.is_none() {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"jwt token scope is Tenant, but tenant id is missing"
|
||||
)
|
||||
)));
|
||||
}
|
||||
|
||||
info!(
|
||||
@@ -701,22 +716,33 @@ impl postgres_backend_async::Handler for PageServerHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_sm: &FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
debug!("process query {:?}", query_string);
|
||||
) -> Result<(), QueryError> {
|
||||
debug!("process query {query_string:?}");
|
||||
|
||||
if query_string.starts_with("pagestream ") {
|
||||
let (_, params_raw) = query_string.split_at("pagestream ".len());
|
||||
let params = params_raw.split(' ').collect::<Vec<_>>();
|
||||
ensure!(
|
||||
params.len() == 2,
|
||||
"invalid param number for pagestream command"
|
||||
);
|
||||
let tenant_id = TenantId::from_str(params[0])?;
|
||||
let timeline_id = TimelineId::from_str(params[1])?;
|
||||
if params.len() != 2 {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"invalid param number for pagestream command"
|
||||
)));
|
||||
}
|
||||
let tenant_id = TenantId::from_str(params[0])
|
||||
.with_context(|| format!("Failed to parse tenant id from {}", params[0]))?;
|
||||
let timeline_id = TimelineId::from_str(params[1])
|
||||
.with_context(|| format!("Failed to parse timeline id from {}", params[1]))?;
|
||||
|
||||
self.check_permission(Some(tenant_id))?;
|
||||
|
||||
@@ -726,18 +752,24 @@ impl postgres_backend_async::Handler for PageServerHandler {
|
||||
let (_, params_raw) = query_string.split_at("basebackup ".len());
|
||||
let params = params_raw.split_whitespace().collect::<Vec<_>>();
|
||||
|
||||
ensure!(
|
||||
params.len() >= 2,
|
||||
"invalid param number for basebackup command"
|
||||
);
|
||||
if params.len() < 2 {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"invalid param number for basebackup command"
|
||||
)));
|
||||
}
|
||||
|
||||
let tenant_id = TenantId::from_str(params[0])?;
|
||||
let timeline_id = TimelineId::from_str(params[1])?;
|
||||
let tenant_id = TenantId::from_str(params[0])
|
||||
.with_context(|| format!("Failed to parse tenant id from {}", params[0]))?;
|
||||
let timeline_id = TimelineId::from_str(params[1])
|
||||
.with_context(|| format!("Failed to parse timeline id from {}", params[1]))?;
|
||||
|
||||
self.check_permission(Some(tenant_id))?;
|
||||
|
||||
let lsn = if params.len() == 3 {
|
||||
Some(Lsn::from_str(params[2])?)
|
||||
Some(
|
||||
Lsn::from_str(params[2])
|
||||
.with_context(|| format!("Failed to parse Lsn from {}", params[2]))?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -752,13 +784,16 @@ impl postgres_backend_async::Handler for PageServerHandler {
|
||||
let (_, params_raw) = query_string.split_at("get_last_record_rlsn ".len());
|
||||
let params = params_raw.split_whitespace().collect::<Vec<_>>();
|
||||
|
||||
ensure!(
|
||||
params.len() == 2,
|
||||
"invalid param number for get_last_record_rlsn command"
|
||||
);
|
||||
if params.len() != 2 {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"invalid param number for get_last_record_rlsn command"
|
||||
)));
|
||||
}
|
||||
|
||||
let tenant_id = TenantId::from_str(params[0])?;
|
||||
let timeline_id = TimelineId::from_str(params[1])?;
|
||||
let tenant_id = TenantId::from_str(params[0])
|
||||
.with_context(|| format!("Failed to parse tenant id from {}", params[0]))?;
|
||||
let timeline_id = TimelineId::from_str(params[1])
|
||||
.with_context(|| format!("Failed to parse timeline id from {}", params[1]))?;
|
||||
|
||||
self.check_permission(Some(tenant_id))?;
|
||||
let timeline = get_active_timeline_with_timeout(tenant_id, timeline_id).await?;
|
||||
@@ -780,22 +815,31 @@ impl postgres_backend_async::Handler for PageServerHandler {
|
||||
let (_, params_raw) = query_string.split_at("fullbackup ".len());
|
||||
let params = params_raw.split_whitespace().collect::<Vec<_>>();
|
||||
|
||||
ensure!(
|
||||
params.len() >= 2,
|
||||
"invalid param number for fullbackup command"
|
||||
);
|
||||
if params.len() < 2 {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"invalid param number for fullbackup command"
|
||||
)));
|
||||
}
|
||||
|
||||
let tenant_id = TenantId::from_str(params[0])?;
|
||||
let timeline_id = TimelineId::from_str(params[1])?;
|
||||
let tenant_id = TenantId::from_str(params[0])
|
||||
.with_context(|| format!("Failed to parse tenant id from {}", params[0]))?;
|
||||
let timeline_id = TimelineId::from_str(params[1])
|
||||
.with_context(|| format!("Failed to parse timeline id from {}", params[1]))?;
|
||||
|
||||
// The caller is responsible for providing correct lsn and prev_lsn.
|
||||
let lsn = if params.len() > 2 {
|
||||
Some(Lsn::from_str(params[2])?)
|
||||
Some(
|
||||
Lsn::from_str(params[2])
|
||||
.with_context(|| format!("Failed to parse Lsn from {}", params[2]))?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let prev_lsn = if params.len() > 3 {
|
||||
Some(Lsn::from_str(params[3])?)
|
||||
Some(
|
||||
Lsn::from_str(params[3])
|
||||
.with_context(|| format!("Failed to parse Lsn from {}", params[3]))?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -820,12 +864,21 @@ impl postgres_backend_async::Handler for PageServerHandler {
|
||||
// -c "import basebackup $TENANT $TIMELINE $START_LSN $END_LSN $PG_VERSION"
|
||||
let (_, params_raw) = query_string.split_at("import basebackup ".len());
|
||||
let params = params_raw.split_whitespace().collect::<Vec<_>>();
|
||||
ensure!(params.len() == 5);
|
||||
let tenant_id = TenantId::from_str(params[0])?;
|
||||
let timeline_id = TimelineId::from_str(params[1])?;
|
||||
let base_lsn = Lsn::from_str(params[2])?;
|
||||
let end_lsn = Lsn::from_str(params[3])?;
|
||||
let pg_version = u32::from_str(params[4])?;
|
||||
if params.len() != 5 {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"invalid param number for import basebackup command"
|
||||
)));
|
||||
}
|
||||
let tenant_id = TenantId::from_str(params[0])
|
||||
.with_context(|| format!("Failed to parse tenant id from {}", params[0]))?;
|
||||
let timeline_id = TimelineId::from_str(params[1])
|
||||
.with_context(|| format!("Failed to parse timeline id from {}", params[1]))?;
|
||||
let base_lsn = Lsn::from_str(params[2])
|
||||
.with_context(|| format!("Failed to parse Lsn from {}", params[2]))?;
|
||||
let end_lsn = Lsn::from_str(params[3])
|
||||
.with_context(|| format!("Failed to parse Lsn from {}", params[3]))?;
|
||||
let pg_version = u32::from_str(params[4])
|
||||
.with_context(|| format!("Failed to parse pg_version from {}", params[4]))?;
|
||||
|
||||
self.check_permission(Some(tenant_id))?;
|
||||
|
||||
@@ -843,7 +896,10 @@ impl postgres_backend_async::Handler for PageServerHandler {
|
||||
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
|
||||
Err(e) => {
|
||||
error!("error importing base backup between {base_lsn} and {end_lsn}: {e:?}");
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&e.to_string()))?
|
||||
pgb.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?
|
||||
}
|
||||
};
|
||||
} else if query_string.starts_with("import wal ") {
|
||||
@@ -853,11 +909,19 @@ impl postgres_backend_async::Handler for PageServerHandler {
|
||||
// caller should poll the http api to check when that is done.
|
||||
let (_, params_raw) = query_string.split_at("import wal ".len());
|
||||
let params = params_raw.split_whitespace().collect::<Vec<_>>();
|
||||
ensure!(params.len() == 4);
|
||||
let tenant_id = TenantId::from_str(params[0])?;
|
||||
let timeline_id = TimelineId::from_str(params[1])?;
|
||||
let start_lsn = Lsn::from_str(params[2])?;
|
||||
let end_lsn = Lsn::from_str(params[3])?;
|
||||
if params.len() != 4 {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"invalid param number for import wal command"
|
||||
)));
|
||||
}
|
||||
let tenant_id = TenantId::from_str(params[0])
|
||||
.with_context(|| format!("Failed to parse tenant id from {}", params[0]))?;
|
||||
let timeline_id = TimelineId::from_str(params[1])
|
||||
.with_context(|| format!("Failed to parse timeline id from {}", params[1]))?;
|
||||
let start_lsn = Lsn::from_str(params[2])
|
||||
.with_context(|| format!("Failed to parse Lsn from {}", params[2]))?;
|
||||
let end_lsn = Lsn::from_str(params[3])
|
||||
.with_context(|| format!("Failed to parse Lsn from {}", params[3]))?;
|
||||
|
||||
self.check_permission(Some(tenant_id))?;
|
||||
|
||||
@@ -868,7 +932,10 @@ impl postgres_backend_async::Handler for PageServerHandler {
|
||||
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
|
||||
Err(e) => {
|
||||
error!("error importing WAL between {start_lsn} and {end_lsn}: {e:?}");
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&e.to_string()))?
|
||||
pgb.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?
|
||||
}
|
||||
};
|
||||
} else if query_string.to_ascii_lowercase().starts_with("set ") {
|
||||
@@ -879,8 +946,13 @@ impl postgres_backend_async::Handler for PageServerHandler {
|
||||
// show <tenant_id>
|
||||
let (_, params_raw) = query_string.split_at("show ".len());
|
||||
let params = params_raw.split(' ').collect::<Vec<_>>();
|
||||
ensure!(params.len() == 1, "invalid param number for config command");
|
||||
let tenant_id = TenantId::from_str(params[0])?;
|
||||
if params.len() != 1 {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"invalid param number for config command"
|
||||
)));
|
||||
}
|
||||
let tenant_id = TenantId::from_str(params[0])
|
||||
.with_context(|| format!("Failed to parse tenant id from {}", params[0]))?;
|
||||
|
||||
self.check_permission(Some(tenant_id))?;
|
||||
|
||||
@@ -921,7 +993,9 @@ impl postgres_backend_async::Handler for PageServerHandler {
|
||||
]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
} else {
|
||||
bail!("unknown command");
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unknown command {query_string}"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -933,7 +1007,7 @@ impl postgres_backend_async::Handler for PageServerHandler {
|
||||
/// If the tenant is Loading, waits for it to become Active, for up to 30 s. That
|
||||
/// ensures that queries don't fail immediately after pageserver startup, because
|
||||
/// all tenants are still loading.
|
||||
async fn get_active_tenant_with_timeout(tenant_id: TenantId) -> Result<Arc<Tenant>> {
|
||||
async fn get_active_tenant_with_timeout(tenant_id: TenantId) -> anyhow::Result<Arc<Tenant>> {
|
||||
let tenant = mgr::get_tenant(tenant_id, false).await?;
|
||||
match tokio::time::timeout(Duration::from_secs(30), tenant.wait_to_become_active()).await {
|
||||
Ok(wait_result) => wait_result
|
||||
@@ -947,7 +1021,7 @@ async fn get_active_tenant_with_timeout(tenant_id: TenantId) -> Result<Arc<Tenan
|
||||
async fn get_active_timeline_with_timeout(
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
) -> Result<Arc<Timeline>> {
|
||||
) -> anyhow::Result<Arc<Timeline>> {
|
||||
get_active_tenant_with_timeout(tenant_id)
|
||||
.await
|
||||
.and_then(|tenant| tenant.get_timeline(timeline_id, true))
|
||||
|
||||
@@ -13,11 +13,13 @@
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use bytes::Bytes;
|
||||
use futures::FutureExt;
|
||||
use futures::Stream;
|
||||
use pageserver_api::models::TimelineState;
|
||||
use remote_storage::DownloadError;
|
||||
use remote_storage::GenericRemoteStorage;
|
||||
use tokio::sync::watch;
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::*;
|
||||
use utils::crashsafe::path_with_suffix_extension;
|
||||
|
||||
@@ -594,7 +596,7 @@ impl Tenant {
|
||||
match tenant_clone.attach().await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
tenant_clone.set_broken();
|
||||
tenant_clone.set_broken(&e.to_string());
|
||||
error!("error attaching tenant: {:?}", e);
|
||||
}
|
||||
}
|
||||
@@ -639,26 +641,62 @@ impl Tenant {
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("cannot attach without remote storage"))?;
|
||||
|
||||
let remote_timelines = remote_timeline_client::list_remote_timelines(
|
||||
let remote_timeline_ids = remote_timeline_client::list_remote_timelines(
|
||||
remote_storage,
|
||||
self.conf,
|
||||
self.tenant_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!("found {} timelines", remote_timelines.len());
|
||||
info!("found {} timelines", remote_timeline_ids.len());
|
||||
|
||||
let mut timeline_ancestors: HashMap<TimelineId, TimelineMetadata> = HashMap::new();
|
||||
let mut index_parts: HashMap<TimelineId, IndexPart> = HashMap::new();
|
||||
for (timeline_id, index_part) in remote_timelines {
|
||||
let remote_metadata = index_part.parse_metadata().with_context(|| {
|
||||
format!(
|
||||
"Failed to parse metadata file from remote storage for tenant {} timeline {}",
|
||||
self.tenant_id, timeline_id
|
||||
)
|
||||
})?;
|
||||
// Download & parse index parts
|
||||
let mut part_downloads = JoinSet::new();
|
||||
for timeline_id in remote_timeline_ids {
|
||||
let client = RemoteTimelineClient::new(
|
||||
remote_storage.clone(),
|
||||
self.conf,
|
||||
self.tenant_id,
|
||||
timeline_id,
|
||||
);
|
||||
part_downloads.spawn(
|
||||
async move {
|
||||
debug!("starting index part download");
|
||||
|
||||
let index_part = client
|
||||
.download_index_file()
|
||||
.await
|
||||
.context("download index file")?;
|
||||
|
||||
let remote_metadata = index_part.parse_metadata().context("parse metadata")?;
|
||||
|
||||
debug!("finished index part download");
|
||||
|
||||
Result::<_, anyhow::Error>::Ok((
|
||||
timeline_id,
|
||||
client,
|
||||
index_part,
|
||||
remote_metadata,
|
||||
))
|
||||
}
|
||||
.map(move |res| {
|
||||
res.with_context(|| format!("download index part for timeline {timeline_id}"))
|
||||
})
|
||||
.instrument(info_span!("download_index_part", timeline=%timeline_id)),
|
||||
);
|
||||
}
|
||||
// Wait for all the download tasks to complete & collect results.
|
||||
let mut remote_clients = HashMap::new();
|
||||
let mut index_parts = HashMap::new();
|
||||
let mut timeline_ancestors = HashMap::new();
|
||||
while let Some(result) = part_downloads.join_next().await {
|
||||
// NB: we already added timeline_id as context to the error
|
||||
let result: Result<_, anyhow::Error> = result.context("joinset task join")?;
|
||||
let (timeline_id, client, index_part, remote_metadata) = result?;
|
||||
debug!("successfully downloaded index part for timeline {timeline_id}");
|
||||
timeline_ancestors.insert(timeline_id, remote_metadata);
|
||||
index_parts.insert(timeline_id, index_part);
|
||||
remote_clients.insert(timeline_id, client);
|
||||
}
|
||||
|
||||
// For every timeline, download the metadata file, scan the local directory,
|
||||
@@ -671,7 +709,7 @@ impl Tenant {
|
||||
timeline_id,
|
||||
index_parts.remove(&timeline_id).unwrap(),
|
||||
remote_metadata,
|
||||
remote_storage.clone(),
|
||||
remote_clients.remove(&timeline_id).unwrap(),
|
||||
)
|
||||
.await
|
||||
.with_context(|| {
|
||||
@@ -714,22 +752,19 @@ impl Tenant {
|
||||
Ok(size)
|
||||
}
|
||||
|
||||
#[instrument(skip(self, index_part, remote_metadata, remote_storage), fields(timeline_id=%timeline_id))]
|
||||
#[instrument(skip_all, fields(timeline_id=%timeline_id))]
|
||||
async fn load_remote_timeline(
|
||||
&self,
|
||||
timeline_id: TimelineId,
|
||||
index_part: IndexPart,
|
||||
remote_metadata: TimelineMetadata,
|
||||
remote_storage: GenericRemoteStorage,
|
||||
remote_client: RemoteTimelineClient,
|
||||
) -> anyhow::Result<()> {
|
||||
info!("downloading index file for timeline {}", timeline_id);
|
||||
tokio::fs::create_dir_all(self.conf.timeline_path(&timeline_id, &self.tenant_id))
|
||||
.await
|
||||
.context("Failed to create new timeline directory")?;
|
||||
|
||||
let remote_client =
|
||||
RemoteTimelineClient::new(remote_storage, self.conf, self.tenant_id, timeline_id)?;
|
||||
|
||||
let ancestor = if let Some(ancestor_id) = remote_metadata.ancestor_timeline() {
|
||||
let timelines = self.timelines.lock().unwrap();
|
||||
Some(Arc::clone(timelines.get(&ancestor_id).ok_or_else(
|
||||
@@ -825,7 +860,7 @@ impl Tenant {
|
||||
match tenant_clone.load().await {
|
||||
Ok(()) => {}
|
||||
Err(err) => {
|
||||
tenant_clone.set_broken();
|
||||
tenant_clone.set_broken(&err.to_string());
|
||||
error!("could not load tenant {tenant_id}: {err:?}");
|
||||
}
|
||||
}
|
||||
@@ -986,18 +1021,14 @@ impl Tenant {
|
||||
None
|
||||
};
|
||||
|
||||
let remote_client = self
|
||||
.remote_storage
|
||||
.as_ref()
|
||||
.map(|remote_storage| {
|
||||
RemoteTimelineClient::new(
|
||||
remote_storage.clone(),
|
||||
self.conf,
|
||||
self.tenant_id,
|
||||
timeline_id,
|
||||
)
|
||||
})
|
||||
.transpose()?;
|
||||
let remote_client = self.remote_storage.as_ref().map(|remote_storage| {
|
||||
RemoteTimelineClient::new(
|
||||
remote_storage.clone(),
|
||||
self.conf,
|
||||
self.tenant_id,
|
||||
timeline_id,
|
||||
)
|
||||
});
|
||||
|
||||
let remote_startup_data = match &remote_client {
|
||||
Some(remote_client) => match remote_client.download_index_file().await {
|
||||
@@ -1465,7 +1496,7 @@ impl Tenant {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn set_broken(&self) {
|
||||
pub fn set_broken(&self, reason: &str) {
|
||||
self.state.send_modify(|current_state| {
|
||||
match *current_state {
|
||||
TenantState::Active => {
|
||||
@@ -1474,18 +1505,22 @@ impl Tenant {
|
||||
// activated should never be marked as broken. We cope with it the best
|
||||
// we can, but it shouldn't happen.
|
||||
*current_state = TenantState::Broken;
|
||||
warn!("Changing Active tenant to Broken state");
|
||||
warn!("Changing Active tenant to Broken state, reason: {}", reason);
|
||||
}
|
||||
TenantState::Broken => {
|
||||
// This shouldn't happen either
|
||||
warn!("Tenant is already broken");
|
||||
warn!("Tenant is already in Broken state");
|
||||
}
|
||||
TenantState::Stopping => {
|
||||
// This shouldn't happen either
|
||||
*current_state = TenantState::Broken;
|
||||
warn!("Marking Stopping tenant as Broken");
|
||||
warn!(
|
||||
"Marking Stopping tenant as Broken state, reason: {}",
|
||||
reason
|
||||
);
|
||||
}
|
||||
TenantState::Loading | TenantState::Attaching => {
|
||||
info!("Setting tenant as Broken state, reason: {}", reason);
|
||||
*current_state = TenantState::Broken;
|
||||
}
|
||||
}
|
||||
@@ -1839,7 +1874,12 @@ impl Tenant {
|
||||
|
||||
utils::failpoint_sleep_millis_async!("gc_iteration_internal_after_getting_gc_timelines");
|
||||
|
||||
info!("starting on {} timelines", gc_timelines.len());
|
||||
// If there is nothing to GC, we don't want any messages in the INFO log.
|
||||
if !gc_timelines.is_empty() {
|
||||
info!("{} timelines need GC", gc_timelines.len());
|
||||
} else {
|
||||
debug!("{} timelines need GC", gc_timelines.len());
|
||||
}
|
||||
|
||||
// Perform GC for each timeline.
|
||||
//
|
||||
@@ -2191,7 +2231,7 @@ impl Tenant {
|
||||
self.conf,
|
||||
tenant_id,
|
||||
new_timeline_id,
|
||||
)?;
|
||||
);
|
||||
remote_client.init_upload_queue_for_empty_remote(&new_metadata)?;
|
||||
Some(remote_client)
|
||||
} else {
|
||||
|
||||
@@ -26,7 +26,7 @@ use std::sync::Arc;
|
||||
use tracing::*;
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use super::storage_layer::{InMemoryLayer, InMemoryOrHistoricLayer, Layer};
|
||||
use super::storage_layer::{InMemoryLayer, Layer};
|
||||
|
||||
///
|
||||
/// LayerMap tracks what layers exist on a timeline.
|
||||
@@ -241,8 +241,7 @@ where
|
||||
|
||||
/// Return value of LayerMap::search
|
||||
pub struct SearchResult<L: ?Sized> {
|
||||
// FIXME: I wish this could be Arc<dyn Layer>. But I couldn't make that work.
|
||||
pub layer: InMemoryOrHistoricLayer<L>,
|
||||
pub layer: Arc<L>,
|
||||
pub lsn_floor: Lsn,
|
||||
}
|
||||
|
||||
@@ -261,32 +260,10 @@ where
|
||||
/// contain the version, even if it's missing from the returned
|
||||
/// layer.
|
||||
///
|
||||
/// NOTE: This only searches the 'historic' layers, *not* the
|
||||
/// 'open' and 'frozen' layers!
|
||||
///
|
||||
pub fn search(&self, key: Key, end_lsn: Lsn) -> Option<SearchResult<L>> {
|
||||
// First check if an open or frozen layer matches
|
||||
if let Some(open_layer) = &self.open_layer {
|
||||
let start_lsn = open_layer.get_lsn_range().start;
|
||||
if end_lsn > start_lsn {
|
||||
return Some(SearchResult {
|
||||
layer: InMemoryOrHistoricLayer::InMemory(Arc::clone(open_layer)),
|
||||
lsn_floor: start_lsn,
|
||||
});
|
||||
}
|
||||
}
|
||||
for frozen_layer in self.frozen_layers.iter().rev() {
|
||||
let start_lsn = frozen_layer.get_lsn_range().start;
|
||||
if end_lsn > start_lsn {
|
||||
return Some(SearchResult {
|
||||
layer: InMemoryOrHistoricLayer::InMemory(Arc::clone(frozen_layer)),
|
||||
lsn_floor: start_lsn,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
self.search_historic(key, end_lsn)
|
||||
}
|
||||
|
||||
fn search_historic(&self, key: Key, end_lsn: Lsn) -> Option<SearchResult<L>> {
|
||||
// linear search
|
||||
// Find the latest image layer that covers the given key
|
||||
let mut latest_img: Option<Arc<L>> = None;
|
||||
let mut latest_img_lsn: Option<Lsn> = None;
|
||||
@@ -311,7 +288,7 @@ where
|
||||
if Lsn(img_lsn.0 + 1) == end_lsn {
|
||||
// found exact match
|
||||
return Some(SearchResult {
|
||||
layer: InMemoryOrHistoricLayer::Historic(Arc::clone(l)),
|
||||
layer: Arc::clone(l),
|
||||
lsn_floor: img_lsn,
|
||||
});
|
||||
}
|
||||
@@ -374,13 +351,13 @@ where
|
||||
);
|
||||
Some(SearchResult {
|
||||
lsn_floor,
|
||||
layer: InMemoryOrHistoricLayer::Historic(l),
|
||||
layer: l,
|
||||
})
|
||||
} else if let Some(l) = latest_img {
|
||||
trace!("found img layer and no deltas for request on {key} at {end_lsn}");
|
||||
Some(SearchResult {
|
||||
lsn_floor: latest_img_lsn.unwrap(),
|
||||
layer: InMemoryOrHistoricLayer::Historic(l),
|
||||
layer: l,
|
||||
})
|
||||
} else {
|
||||
trace!("no layer found for request on {key} at {end_lsn}");
|
||||
|
||||
@@ -430,7 +430,7 @@ where
|
||||
Err(e) => {
|
||||
let tenants_accessor = TENANTS.read().await;
|
||||
match tenants_accessor.get(&tenant_id) {
|
||||
Some(tenant) => tenant.set_broken(),
|
||||
Some(tenant) => tenant.set_broken(&e.to_string()),
|
||||
None => warn!("Tenant {tenant_id} got removed from memory"),
|
||||
}
|
||||
Err(e)
|
||||
@@ -492,3 +492,53 @@ pub async fn immediate_gc(
|
||||
|
||||
Ok(wait_task_done)
|
||||
}
|
||||
|
||||
#[cfg(feature = "testing")]
|
||||
pub async fn immediate_compact(
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
) -> Result<tokio::sync::oneshot::Receiver<anyhow::Result<()>>, ApiError> {
|
||||
let guard = TENANTS.read().await;
|
||||
|
||||
let tenant = guard
|
||||
.get(&tenant_id)
|
||||
.map(Arc::clone)
|
||||
.with_context(|| format!("Tenant {tenant_id} not found"))
|
||||
.map_err(ApiError::NotFound)?;
|
||||
|
||||
let timeline = tenant
|
||||
.get_timeline(timeline_id, true)
|
||||
.map_err(ApiError::NotFound)?;
|
||||
|
||||
// Run in task_mgr to avoid race with detach operation
|
||||
let (task_done, wait_task_done) = tokio::sync::oneshot::channel();
|
||||
task_mgr::spawn(
|
||||
&tokio::runtime::Handle::current(),
|
||||
TaskKind::Compaction,
|
||||
Some(tenant_id),
|
||||
Some(timeline_id),
|
||||
&format!(
|
||||
"timeline_compact_handler compaction run for tenant {tenant_id} timeline {timeline_id}"
|
||||
),
|
||||
false,
|
||||
async move {
|
||||
let result = timeline
|
||||
.compact()
|
||||
.instrument(
|
||||
info_span!("manual_compact", tenant = %tenant_id, timeline = %timeline_id),
|
||||
)
|
||||
.await;
|
||||
|
||||
match task_done.send(result) {
|
||||
Ok(_) => (),
|
||||
Err(result) => error!("failed to send compaction result: {result:?}"),
|
||||
}
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
|
||||
// drop the guard until after we've spawned the task so that timeline shutdown will wait for the task
|
||||
drop(guard);
|
||||
|
||||
Ok(wait_task_done)
|
||||
}
|
||||
|
||||
@@ -298,8 +298,8 @@ impl RemoteTimelineClient {
|
||||
conf: &'static PageServerConf,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
) -> anyhow::Result<RemoteTimelineClient> {
|
||||
Ok(RemoteTimelineClient {
|
||||
) -> RemoteTimelineClient {
|
||||
RemoteTimelineClient {
|
||||
conf,
|
||||
runtime: &BACKGROUND_RUNTIME,
|
||||
tenant_id,
|
||||
@@ -307,7 +307,7 @@ impl RemoteTimelineClient {
|
||||
storage_impl: remote_storage,
|
||||
upload_queue: Mutex::new(UploadQueue::Uninitialized),
|
||||
metrics: Arc::new(RemoteTimelineClientMetrics::new(&tenant_id, &timeline_id)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the upload queue for a remote storage that already received
|
||||
@@ -367,6 +367,10 @@ impl RemoteTimelineClient {
|
||||
|
||||
/// Download index file
|
||||
pub async fn download_index_file(&self) -> Result<IndexPart, DownloadError> {
|
||||
let _unfinished_gauge_guard = self
|
||||
.metrics
|
||||
.call_begin(&RemoteOpFileKind::Index, &RemoteOpKind::Download);
|
||||
|
||||
download::download_index_part(
|
||||
self.conf,
|
||||
&self.storage_impl,
|
||||
@@ -393,22 +397,27 @@ impl RemoteTimelineClient {
|
||||
layer_file_name: &LayerFileName,
|
||||
layer_metadata: &LayerFileMetadata,
|
||||
) -> anyhow::Result<u64> {
|
||||
let downloaded_size = download::download_layer_file(
|
||||
self.conf,
|
||||
&self.storage_impl,
|
||||
self.tenant_id,
|
||||
self.timeline_id,
|
||||
layer_file_name,
|
||||
layer_metadata,
|
||||
)
|
||||
.measure_remote_op(
|
||||
self.tenant_id,
|
||||
self.timeline_id,
|
||||
RemoteOpFileKind::Layer,
|
||||
RemoteOpKind::Download,
|
||||
Arc::clone(&self.metrics),
|
||||
)
|
||||
.await?;
|
||||
let downloaded_size = {
|
||||
let _unfinished_gauge_guard = self
|
||||
.metrics
|
||||
.call_begin(&RemoteOpFileKind::Layer, &RemoteOpKind::Download);
|
||||
download::download_layer_file(
|
||||
self.conf,
|
||||
&self.storage_impl,
|
||||
self.tenant_id,
|
||||
self.timeline_id,
|
||||
layer_file_name,
|
||||
layer_metadata,
|
||||
)
|
||||
.measure_remote_op(
|
||||
self.tenant_id,
|
||||
self.timeline_id,
|
||||
RemoteOpFileKind::Layer,
|
||||
RemoteOpKind::Download,
|
||||
Arc::clone(&self.metrics),
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
// Update the metadata for given layer file. The remote index file
|
||||
// might be missing some information for the file; this allows us
|
||||
@@ -517,7 +526,7 @@ impl RemoteTimelineClient {
|
||||
metadata_bytes,
|
||||
);
|
||||
let op = UploadOp::UploadMetadata(index_part, disk_consistent_lsn);
|
||||
self.update_upload_queue_unfinished_metric(1, &op);
|
||||
self.calls_unfinished_metric_begin(&op);
|
||||
upload_queue.queued_operations.push_back(op);
|
||||
upload_queue.latest_files_changes_since_metadata_upload_scheduled = 0;
|
||||
|
||||
@@ -549,7 +558,7 @@ impl RemoteTimelineClient {
|
||||
upload_queue.latest_files_changes_since_metadata_upload_scheduled += 1;
|
||||
|
||||
let op = UploadOp::UploadLayer(layer_file_name.clone(), layer_metadata.clone());
|
||||
self.update_upload_queue_unfinished_metric(1, &op);
|
||||
self.calls_unfinished_metric_begin(&op);
|
||||
upload_queue.queued_operations.push_back(op);
|
||||
|
||||
info!(
|
||||
@@ -601,7 +610,7 @@ impl RemoteTimelineClient {
|
||||
// schedule the actual deletions
|
||||
for name in names {
|
||||
let op = UploadOp::Delete(RemoteOpFileKind::Layer, name.clone());
|
||||
self.update_upload_queue_unfinished_metric(1, &op);
|
||||
self.calls_unfinished_metric_begin(&op);
|
||||
upload_queue.queued_operations.push_back(op);
|
||||
info!("scheduled layer file deletion {}", name.file_name());
|
||||
}
|
||||
@@ -753,7 +762,7 @@ impl RemoteTimelineClient {
|
||||
// upload finishes or times out soon enough.
|
||||
if task_mgr::is_shutdown_requested() {
|
||||
info!("upload task cancelled by shutdown request");
|
||||
self.update_upload_queue_unfinished_metric(-1, &task.op);
|
||||
self.calls_unfinished_metric_end(&task.op);
|
||||
self.stop();
|
||||
return;
|
||||
}
|
||||
@@ -901,22 +910,40 @@ impl RemoteTimelineClient {
|
||||
// Launch any queued tasks that were unblocked by this one.
|
||||
self.launch_queued_tasks(upload_queue);
|
||||
}
|
||||
self.update_upload_queue_unfinished_metric(-1, &task.op);
|
||||
self.calls_unfinished_metric_end(&task.op);
|
||||
}
|
||||
|
||||
fn update_upload_queue_unfinished_metric(&self, delta: i64, op: &UploadOp) {
|
||||
let (file_kind, op_kind) = match op {
|
||||
fn calls_unfinished_metric_impl(
|
||||
&self,
|
||||
op: &UploadOp,
|
||||
) -> Option<(RemoteOpFileKind, RemoteOpKind)> {
|
||||
let res = match op {
|
||||
UploadOp::UploadLayer(_, _) => (RemoteOpFileKind::Layer, RemoteOpKind::Upload),
|
||||
UploadOp::UploadMetadata(_, _) => (RemoteOpFileKind::Index, RemoteOpKind::Upload),
|
||||
UploadOp::Delete(file_kind, _) => (*file_kind, RemoteOpKind::Delete),
|
||||
UploadOp::Barrier(_) => {
|
||||
// we do not account these
|
||||
return;
|
||||
return None;
|
||||
}
|
||||
};
|
||||
self.metrics
|
||||
.unfinished_tasks(&file_kind, &op_kind)
|
||||
.add(delta)
|
||||
Some(res)
|
||||
}
|
||||
|
||||
fn calls_unfinished_metric_begin(&self, op: &UploadOp) {
|
||||
let (file_kind, op_kind) = match self.calls_unfinished_metric_impl(op) {
|
||||
Some(x) => x,
|
||||
None => return,
|
||||
};
|
||||
let guard = self.metrics.call_begin(&file_kind, &op_kind);
|
||||
guard.will_decrement_manually(); // in unfinished_ops_metric_end()
|
||||
}
|
||||
|
||||
fn calls_unfinished_metric_end(&self, op: &UploadOp) {
|
||||
let (file_kind, op_kind) = match self.calls_unfinished_metric_impl(op) {
|
||||
Some(x) => x,
|
||||
None => return,
|
||||
};
|
||||
self.metrics.call_end(&file_kind, &op_kind);
|
||||
}
|
||||
|
||||
fn stop(&self) {
|
||||
@@ -967,7 +994,7 @@ impl RemoteTimelineClient {
|
||||
|
||||
// Tear down queued ops
|
||||
for op in qi.queued_operations.into_iter() {
|
||||
self.update_upload_queue_unfinished_metric(-1, &op);
|
||||
self.calls_unfinished_metric_end(&op);
|
||||
// Dropping UploadOp::Barrier() here will make wait_completion() return with an Err()
|
||||
// which is exactly what we want to happen.
|
||||
drop(op);
|
||||
|
||||
@@ -8,10 +8,9 @@ use std::future::Future;
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use futures::stream::{FuturesUnordered, StreamExt};
|
||||
use tokio::fs;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::config::PageServerConf;
|
||||
use crate::tenant::storage_layer::LayerFileName;
|
||||
@@ -175,7 +174,7 @@ pub async fn list_remote_timelines<'a>(
|
||||
storage: &'a GenericRemoteStorage,
|
||||
conf: &'static PageServerConf,
|
||||
tenant_id: TenantId,
|
||||
) -> anyhow::Result<Vec<(TimelineId, IndexPart)>> {
|
||||
) -> anyhow::Result<HashSet<TimelineId>> {
|
||||
let tenant_path = conf.timelines_path(&tenant_id);
|
||||
let tenant_storage_path = conf.remote_path(&tenant_path)?;
|
||||
|
||||
@@ -194,7 +193,6 @@ pub async fn list_remote_timelines<'a>(
|
||||
}
|
||||
|
||||
let mut timeline_ids = HashSet::new();
|
||||
let mut part_downloads = FuturesUnordered::new();
|
||||
|
||||
for timeline_remote_storage_key in timelines {
|
||||
let object_name = timeline_remote_storage_key.object_name().ok_or_else(|| {
|
||||
@@ -205,35 +203,22 @@ pub async fn list_remote_timelines<'a>(
|
||||
format!("failed to parse object name into timeline id '{object_name}'")
|
||||
})?;
|
||||
|
||||
// list_prefixes returns all files with the prefix. If we haven't seen this timeline ID
|
||||
// yet, launch a download task for it.
|
||||
if !timeline_ids.contains(&timeline_id) {
|
||||
timeline_ids.insert(timeline_id);
|
||||
let storage_clone = storage.clone();
|
||||
part_downloads.push(async move {
|
||||
(
|
||||
timeline_id,
|
||||
download_index_part(conf, &storage_clone, tenant_id, timeline_id)
|
||||
.instrument(info_span!("download_index_part", timeline=%timeline_id))
|
||||
.await,
|
||||
)
|
||||
});
|
||||
}
|
||||
// list_prefixes is assumed to return unique names. Ensure this here.
|
||||
// NB: it's safer to bail out than warn-log this because the pageserver
|
||||
// needs to absolutely know about _all_ timelines that exist, so that
|
||||
// GC knows all the branchpoints. If we skipped over a timeline instead,
|
||||
// GC could delete a layer that's still needed by that timeline.
|
||||
anyhow::ensure!(
|
||||
!timeline_ids.contains(&timeline_id),
|
||||
"list_prefixes contains duplicate timeline id {timeline_id}"
|
||||
);
|
||||
timeline_ids.insert(timeline_id);
|
||||
}
|
||||
|
||||
// Wait for all the download tasks to complete.
|
||||
let mut timeline_parts = Vec::new();
|
||||
while let Some((timeline_id, part_upload_result)) = part_downloads.next().await {
|
||||
let index_part = part_upload_result
|
||||
.with_context(|| format!("Failed to fetch index part for timeline {timeline_id}"))?;
|
||||
|
||||
debug!("Successfully fetched index part for timeline {timeline_id}");
|
||||
timeline_parts.push((timeline_id, index_part));
|
||||
}
|
||||
Ok(timeline_parts)
|
||||
Ok(timeline_ids)
|
||||
}
|
||||
|
||||
pub async fn download_index_part(
|
||||
pub(super) async fn download_index_part(
|
||||
conf: &'static PageServerConf,
|
||||
storage: &GenericRemoteStorage,
|
||||
tenant_id: TenantId,
|
||||
|
||||
@@ -196,38 +196,3 @@ pub fn downcast_remote_layer(
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub enum InMemoryOrHistoricLayer<L: ?Sized> {
|
||||
InMemory(Arc<InMemoryLayer>),
|
||||
Historic(Arc<L>),
|
||||
}
|
||||
|
||||
impl<L: ?Sized> InMemoryOrHistoricLayer<L>
|
||||
where
|
||||
L: PersistentLayer,
|
||||
{
|
||||
pub fn downcast_remote_layer(&self) -> Option<std::sync::Arc<RemoteLayer>> {
|
||||
match self {
|
||||
Self::InMemory(_) => None,
|
||||
Self::Historic(l) => {
|
||||
if l.is_remote_layer() {
|
||||
Arc::clone(l).downcast_remote_layer()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_value_reconstruct_data(
|
||||
&self,
|
||||
key: Key,
|
||||
lsn_range: Range<Lsn>,
|
||||
reconstruct_data: &mut ValueReconstructState,
|
||||
) -> Result<ValueReconstructResult> {
|
||||
match self {
|
||||
Self::InMemory(l) => l.get_value_reconstruct_data(key, lsn_range, reconstruct_data),
|
||||
Self::Historic(l) => l.get_value_reconstruct_data(key, lsn_range, reconstruct_data),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,8 +25,8 @@ use std::time::{Duration, Instant, SystemTime};
|
||||
|
||||
use crate::tenant::remote_timeline_client::{self, index::LayerFileMetadata};
|
||||
use crate::tenant::storage_layer::{
|
||||
DeltaFileName, DeltaLayerWriter, ImageFileName, ImageLayerWriter, InMemoryLayer,
|
||||
InMemoryOrHistoricLayer, LayerFileName, RemoteLayer,
|
||||
DeltaFileName, DeltaLayerWriter, ImageFileName, ImageLayerWriter, InMemoryLayer, LayerFileName,
|
||||
RemoteLayer,
|
||||
};
|
||||
use crate::tenant::{
|
||||
ephemeral_file::is_ephemeral_file,
|
||||
@@ -1591,7 +1591,7 @@ trait TraversalLayerExt {
|
||||
fn traversal_id(&self) -> TraversalId;
|
||||
}
|
||||
|
||||
impl<T: PersistentLayer + ?Sized> TraversalLayerExt for T {
|
||||
impl TraversalLayerExt for Arc<dyn PersistentLayer> {
|
||||
fn traversal_id(&self) -> TraversalId {
|
||||
match self.local_path() {
|
||||
Some(local_path) => {
|
||||
@@ -1621,15 +1621,6 @@ impl TraversalLayerExt for Arc<InMemoryLayer> {
|
||||
}
|
||||
}
|
||||
|
||||
impl TraversalLayerExt for InMemoryOrHistoricLayer<dyn PersistentLayer> {
|
||||
fn traversal_id(&self) -> String {
|
||||
match self {
|
||||
Self::InMemory(l) => l.traversal_id(),
|
||||
Self::Historic(l) => l.traversal_id(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Timeline {
|
||||
///
|
||||
/// Get a handle to a Layer for reading.
|
||||
@@ -1651,11 +1642,7 @@ impl Timeline {
|
||||
|
||||
// For debugging purposes, collect the path of layers that we traversed
|
||||
// through. It's included in the error message if we fail to find the key.
|
||||
let mut traversal_path = Vec::<(
|
||||
ValueReconstructResult,
|
||||
Lsn,
|
||||
Box<dyn TraversalLayerExt>,
|
||||
)>::new();
|
||||
let mut traversal_path = Vec::<TraversalPathItem>::new();
|
||||
|
||||
let cached_lsn = if let Some((cached_lsn, _)) = &reconstruct_state.img {
|
||||
*cached_lsn
|
||||
@@ -1691,7 +1678,7 @@ impl Timeline {
|
||||
Lsn(cont_lsn.0 - 1),
|
||||
request_lsn,
|
||||
timeline.ancestor_lsn
|
||||
), &traversal_path);
|
||||
), traversal_path);
|
||||
}
|
||||
prev_lsn = cont_lsn;
|
||||
}
|
||||
@@ -1701,7 +1688,7 @@ impl Timeline {
|
||||
"could not find data for key {} at LSN {}, for request at LSN {}",
|
||||
key, cont_lsn, request_lsn
|
||||
),
|
||||
&traversal_path,
|
||||
traversal_path,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1723,15 +1710,74 @@ impl Timeline {
|
||||
continue 'outer;
|
||||
}
|
||||
|
||||
loop {
|
||||
#[allow(clippy::never_loop)] // see comment at bottom of this loop
|
||||
'_layer_map_search: loop {
|
||||
let remote_layer = {
|
||||
let layers = timeline.layers.read().unwrap();
|
||||
|
||||
// Check the open and frozen in-memory layers first, in order from newest
|
||||
// to oldest.
|
||||
if let Some(open_layer) = &layers.open_layer {
|
||||
let start_lsn = open_layer.get_lsn_range().start;
|
||||
if cont_lsn > start_lsn {
|
||||
//info!("CHECKING for {} at {} on open layer {}", key, cont_lsn, open_layer.filename().display());
|
||||
// Get all the data needed to reconstruct the page version from this layer.
|
||||
// But if we have an older cached page image, no need to go past that.
|
||||
let lsn_floor = max(cached_lsn + 1, start_lsn);
|
||||
result = match open_layer.get_value_reconstruct_data(
|
||||
key,
|
||||
lsn_floor..cont_lsn,
|
||||
reconstruct_state,
|
||||
) {
|
||||
Ok(result) => result,
|
||||
Err(e) => return PageReconstructResult::from(e),
|
||||
};
|
||||
cont_lsn = lsn_floor;
|
||||
traversal_path.push((
|
||||
result,
|
||||
cont_lsn,
|
||||
Box::new({
|
||||
let open_layer = Arc::clone(open_layer);
|
||||
move || open_layer.traversal_id()
|
||||
}),
|
||||
));
|
||||
continue 'outer;
|
||||
}
|
||||
}
|
||||
for frozen_layer in layers.frozen_layers.iter().rev() {
|
||||
let start_lsn = frozen_layer.get_lsn_range().start;
|
||||
if cont_lsn > start_lsn {
|
||||
//info!("CHECKING for {} at {} on frozen layer {}", key, cont_lsn, frozen_layer.filename().display());
|
||||
let lsn_floor = max(cached_lsn + 1, start_lsn);
|
||||
result = match frozen_layer.get_value_reconstruct_data(
|
||||
key,
|
||||
lsn_floor..cont_lsn,
|
||||
reconstruct_state,
|
||||
) {
|
||||
Ok(result) => result,
|
||||
Err(e) => return PageReconstructResult::from(e),
|
||||
};
|
||||
cont_lsn = lsn_floor;
|
||||
traversal_path.push((
|
||||
result,
|
||||
cont_lsn,
|
||||
Box::new({
|
||||
let frozen_layer = Arc::clone(frozen_layer);
|
||||
move || frozen_layer.traversal_id()
|
||||
}),
|
||||
));
|
||||
continue 'outer;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(SearchResult { lsn_floor, layer }) = layers.search(key, cont_lsn) {
|
||||
// If it's a remote layer, download it and retry.
|
||||
if let Some(remote_layer) = layer.downcast_remote_layer() {
|
||||
if let Some(remote_layer) =
|
||||
super::storage_layer::downcast_remote_layer(&layer)
|
||||
{
|
||||
// TODO: push a breadcrumb to 'traversal_path' to record the fact that
|
||||
// we downloaded / would need to download this.
|
||||
remote_layer
|
||||
// we downloaded / would need to download this layer.
|
||||
remote_layer // download happens outside the scope of `layers` guard object
|
||||
} else {
|
||||
// Get all the data needed to reconstruct the page version from this layer.
|
||||
// But if we have an older cached page image, no need to go past that.
|
||||
@@ -1745,7 +1791,14 @@ impl Timeline {
|
||||
Err(e) => return PageReconstructResult::from(e),
|
||||
};
|
||||
cont_lsn = lsn_floor;
|
||||
traversal_path.push((result, cont_lsn, Box::new(layer)));
|
||||
traversal_path.push((
|
||||
result,
|
||||
cont_lsn,
|
||||
Box::new({
|
||||
let layer = Arc::clone(&layer);
|
||||
move || layer.traversal_id()
|
||||
}),
|
||||
));
|
||||
continue 'outer;
|
||||
}
|
||||
} else if timeline.ancestor_timeline.is_some() {
|
||||
@@ -1759,11 +1812,23 @@ impl Timeline {
|
||||
continue 'outer;
|
||||
}
|
||||
};
|
||||
|
||||
// The next layer doesn't exist locally. The caller can do the download and retry.
|
||||
// (The control flow is a bit complicated here because we must drop the 'layers'
|
||||
// lock before awaiting on the Future.)
|
||||
info!("need remote layer {}", remote_layer.traversal_id());
|
||||
// Indicate to the caller that we need remote_layer replaced with a downloaded
|
||||
// layer in the layer map. The control flow could be a lot simpler, but the point
|
||||
// of this commit is to prepare this function to
|
||||
// 1. become async
|
||||
// 2. do the download right here, using
|
||||
// ```
|
||||
// download_remote_layer().await?;
|
||||
// continue 'layer_map_search;
|
||||
// ```
|
||||
// For (2), current rustc requires that the layers lock guard is not in scope.
|
||||
// Hence, the complicated control flow.
|
||||
let remote_layer_as_persistent: Arc<dyn PersistentLayer> =
|
||||
Arc::clone(&remote_layer) as Arc<dyn PersistentLayer>;
|
||||
info!(
|
||||
"need remote layer {}",
|
||||
remote_layer_as_persistent.traversal_id()
|
||||
);
|
||||
return PageReconstructResult::NeedsDownload(
|
||||
Weak::clone(&timeline.myself),
|
||||
Arc::downgrade(&remote_layer),
|
||||
@@ -3342,22 +3407,25 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
type TraversalPathItem = (
|
||||
ValueReconstructResult,
|
||||
Lsn,
|
||||
Box<dyn FnOnce() -> TraversalId>,
|
||||
);
|
||||
|
||||
/// Helper function for get_reconstruct_data() to add the path of layers traversed
|
||||
/// to an error, as anyhow context information.
|
||||
fn layer_traversal_error(
|
||||
msg: String,
|
||||
path: &[(ValueReconstructResult, Lsn, Box<dyn TraversalLayerExt>)],
|
||||
) -> PageReconstructResult<()> {
|
||||
fn layer_traversal_error(msg: String, path: Vec<TraversalPathItem>) -> PageReconstructResult<()> {
|
||||
// We want the original 'msg' to be the outermost context. The outermost context
|
||||
// is the most high-level information, which also gets propagated to the client.
|
||||
let mut msg_iter = path
|
||||
.iter()
|
||||
.into_iter()
|
||||
.map(|(r, c, l)| {
|
||||
format!(
|
||||
"layer traversal: result {:?}, cont_lsn {}, layer: {}",
|
||||
r,
|
||||
c,
|
||||
l.traversal_id(),
|
||||
l(),
|
||||
)
|
||||
})
|
||||
.chain(std::iter::once(msg));
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Actual Postgres connection handler to stream WAL to the server.
|
||||
|
||||
use std::{
|
||||
error::Error,
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime},
|
||||
@@ -11,7 +12,7 @@ use bytes::BytesMut;
|
||||
use chrono::{NaiveDateTime, Utc};
|
||||
use fail::fail_point;
|
||||
use futures::StreamExt;
|
||||
use postgres::{SimpleQueryMessage, SimpleQueryRow};
|
||||
use postgres::{error::SqlState, SimpleQueryMessage, SimpleQueryRow};
|
||||
use postgres_ffi::v14::xlog_utils::normalize_lsn;
|
||||
use postgres_ffi::WAL_SEGMENT_SIZE;
|
||||
use postgres_protocol::message::backend::ReplicationMessage;
|
||||
@@ -32,7 +33,7 @@ use crate::{
|
||||
use postgres_connection::PgConnectionConfig;
|
||||
use postgres_ffi::waldecoder::WalStreamDecoder;
|
||||
use pq_proto::ReplicationFeedback;
|
||||
use utils::lsn::Lsn;
|
||||
use utils::{lsn::Lsn, postgres_backend_async::is_expected_io_error};
|
||||
|
||||
/// Status of the connection.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
@@ -68,10 +69,17 @@ pub async fn handle_walreceiver_connection(
|
||||
let mut config = wal_source_connconf.to_tokio_postgres_config();
|
||||
config.application_name("pageserver");
|
||||
config.replication_mode(tokio_postgres::config::ReplicationMode::Physical);
|
||||
time::timeout(connect_timeout, config.connect(postgres::NoTls))
|
||||
.await
|
||||
.context("Timed out while waiting for walreceiver connection to open")?
|
||||
.context("Failed to open walreceiver connection")?
|
||||
match time::timeout(connect_timeout, config.connect(postgres::NoTls)).await {
|
||||
Ok(Ok(client_and_conn)) => client_and_conn,
|
||||
Ok(Err(conn_err)) => {
|
||||
let expected_error = ignore_expected_errors(conn_err)?;
|
||||
info!("DB connection stream finished: {expected_error}");
|
||||
return Ok(());
|
||||
}
|
||||
Err(elapsed) => anyhow::bail!(
|
||||
"Timed out while waiting {elapsed} for walreceiver connection to open"
|
||||
),
|
||||
}
|
||||
};
|
||||
|
||||
info!("connected!");
|
||||
@@ -103,10 +111,8 @@ pub async fn handle_walreceiver_connection(
|
||||
connection_result = connection => match connection_result{
|
||||
Ok(()) => info!("Walreceiver db connection closed"),
|
||||
Err(connection_error) => {
|
||||
if connection_error.is_closed() {
|
||||
info!("Connection closed regularly: {connection_error}")
|
||||
} else {
|
||||
warn!("Connection aborted: {connection_error}")
|
||||
if let Err(e) = ignore_expected_errors(connection_error) {
|
||||
warn!("Connection aborted: {e:#}")
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -187,14 +193,9 @@ pub async fn handle_walreceiver_connection(
|
||||
let replication_message = match replication_message {
|
||||
Ok(message) => message,
|
||||
Err(replication_error) => {
|
||||
if replication_error.is_closed() {
|
||||
info!("Replication stream got closed");
|
||||
return Ok(());
|
||||
} else {
|
||||
return Err(
|
||||
anyhow::Error::new(replication_error).context("replication stream error")
|
||||
);
|
||||
}
|
||||
let expected_error = ignore_expected_errors(replication_error)?;
|
||||
info!("Replication stream finished: {expected_error}");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
@@ -400,3 +401,32 @@ async fn identify_system(client: &mut Client) -> anyhow::Result<IdentifySystem>
|
||||
Err(IdentifyError.into())
|
||||
}
|
||||
}
|
||||
|
||||
/// We don't want to report connectivity problems as real errors towards connection manager because
|
||||
/// 1. they happen frequently enough to make server logs hard to read and
|
||||
/// 2. the connection manager can retry other safekeeper.
|
||||
///
|
||||
/// If this function returns `Ok(pg_error)`, it's such an error.
|
||||
/// The caller should log it at info level and then report to connection manager that we're done handling this connection.
|
||||
/// Connection manager will then handle reconnections.
|
||||
///
|
||||
/// If this function returns an `Err()`, the caller can bubble it up using `?`.
|
||||
/// The connection manager will log the error at ERROR level.
|
||||
fn ignore_expected_errors(pg_error: postgres::Error) -> anyhow::Result<postgres::Error> {
|
||||
if pg_error.is_closed()
|
||||
|| pg_error
|
||||
.source()
|
||||
.and_then(|source| source.downcast_ref::<std::io::Error>())
|
||||
.map(is_expected_io_error)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return Ok(pg_error);
|
||||
} else if let Some(db_error) = pg_error.as_db_error() {
|
||||
if db_error.code() == &SqlState::CONNECTION_FAILURE
|
||||
&& db_error.message().contains("end streaming")
|
||||
{
|
||||
return Ok(pg_error);
|
||||
}
|
||||
}
|
||||
Err(pg_error).context("connection error")
|
||||
}
|
||||
|
||||
@@ -111,6 +111,7 @@ pageserver_connect()
|
||||
PQfinish(pageserver_conn);
|
||||
pageserver_conn = NULL;
|
||||
FreeWaitEventSet(pageserver_conn_wes);
|
||||
pageserver_conn_wes = NULL;
|
||||
|
||||
neon_log(ERROR, "could not complete handshake with pageserver: %s",
|
||||
msg);
|
||||
@@ -179,7 +180,10 @@ pageserver_disconnect(void)
|
||||
prefetch_on_ps_disconnect();
|
||||
}
|
||||
if (pageserver_conn_wes != NULL)
|
||||
{
|
||||
FreeWaitEventSet(pageserver_conn_wes);
|
||||
pageserver_conn_wes = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
@@ -206,7 +210,7 @@ pageserver_send(NeonRequest * request)
|
||||
*/
|
||||
if (PQputCopyData(pageserver_conn, req_buff.data, req_buff.len) <= 0)
|
||||
{
|
||||
char *msg = PQerrorMessage(pageserver_conn);
|
||||
char *msg = pchomp(PQerrorMessage(pageserver_conn));
|
||||
|
||||
pageserver_disconnect();
|
||||
neon_log(ERROR, "failed to send page request: %s", msg);
|
||||
@@ -239,29 +243,33 @@ pageserver_receive(void)
|
||||
PG_TRY();
|
||||
{
|
||||
/* read response */
|
||||
resp_buff.len = call_PQgetCopyData(&resp_buff.data);
|
||||
resp_buff.cursor = 0;
|
||||
int rc;
|
||||
|
||||
if (resp_buff.len < 0)
|
||||
rc = call_PQgetCopyData(&resp_buff.data);
|
||||
if (rc >= 0)
|
||||
{
|
||||
if (resp_buff.len == -1)
|
||||
resp_buff.len = rc;
|
||||
resp_buff.cursor = 0;
|
||||
resp = nm_unpack_response(&resp_buff);
|
||||
PQfreemem(resp_buff.data);
|
||||
|
||||
if (message_level_is_interesting(PageStoreTrace))
|
||||
{
|
||||
pageserver_disconnect();
|
||||
return NULL;
|
||||
char *msg = nm_to_string((NeonMessage *) resp);
|
||||
|
||||
neon_log(PageStoreTrace, "got response: %s", msg);
|
||||
pfree(msg);
|
||||
}
|
||||
else if (resp_buff.len == -2)
|
||||
neon_log(ERROR, "could not read COPY data: %s", PQerrorMessage(pageserver_conn));
|
||||
}
|
||||
resp = nm_unpack_response(&resp_buff);
|
||||
PQfreemem(resp_buff.data);
|
||||
|
||||
if (message_level_is_interesting(PageStoreTrace))
|
||||
else if (rc == -1)
|
||||
{
|
||||
char *msg = nm_to_string((NeonMessage *) resp);
|
||||
|
||||
neon_log(PageStoreTrace, "got response: %s", msg);
|
||||
pfree(msg);
|
||||
pageserver_disconnect();
|
||||
resp = NULL;
|
||||
}
|
||||
else if (rc == -2)
|
||||
neon_log(ERROR, "could not read COPY data: %s", PQerrorMessage(pageserver_conn));
|
||||
else
|
||||
neon_log(ERROR, "unexpected PQgetCopyData return value: %d", rc);
|
||||
}
|
||||
PG_CATCH();
|
||||
{
|
||||
|
||||
@@ -17,12 +17,14 @@ hashbrown = "0.12"
|
||||
hex = "0.4.3"
|
||||
hmac = "0.12.1"
|
||||
hyper = "0.14"
|
||||
hyper-tungstenite = "0.8.1"
|
||||
itertools = "0.10.3"
|
||||
md5 = "0.7.0"
|
||||
once_cell = "1.13.0"
|
||||
parking_lot = "0.12"
|
||||
pin-project-lite = "0.2.7"
|
||||
rand = "0.8.3"
|
||||
regex = "1.4.5"
|
||||
reqwest = { version = "0.11", default-features = false, features = [ "json", "rustls-tls" ] }
|
||||
routerify = "3"
|
||||
rustls = "0.20.0"
|
||||
@@ -36,10 +38,12 @@ thiserror = "1.0.30"
|
||||
tokio = { version = "1.17", features = ["macros"] }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||
tokio-rustls = "0.23.0"
|
||||
tls-listener = { version = "0.5.1", features = ["rustls", "hyper-h1"] }
|
||||
tracing = "0.1.36"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
url = "2.2.2"
|
||||
uuid = { version = "1.2", features = ["v4", "serde"] }
|
||||
webpki-roots = "0.22.5"
|
||||
x509-parser = "0.14"
|
||||
|
||||
metrics = { path = "../libs/metrics" }
|
||||
|
||||
@@ -149,7 +149,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the project name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
let fetch_magic_payload = async {
|
||||
let fetch_magic_payload = |client| async {
|
||||
warn!("project name not specified, resorting to the password hack auth flow");
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::PasswordHack)
|
||||
@@ -161,10 +161,26 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
auth::Result::Ok(payload)
|
||||
};
|
||||
|
||||
// If we want to use cleartext password flow, we can read the password
|
||||
// from the client and pretend that it's a magic payload (PasswordHack hack).
|
||||
let fetch_plaintext_password = |client| async {
|
||||
info!("using cleartext password flow");
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::CleartextPassword)
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
auth::Result::Ok(auth::password_hack::PasswordHackPayload {
|
||||
project: String::new(),
|
||||
password: payload,
|
||||
})
|
||||
};
|
||||
|
||||
// TODO: find a proper way to merge those very similar blocks.
|
||||
let (mut node, payload) = match self {
|
||||
Console(endpoint, creds) if creds.project.is_none() => {
|
||||
let payload = fetch_magic_payload.await?;
|
||||
let payload = fetch_magic_payload(client).await?;
|
||||
|
||||
let mut creds = creds.as_ref();
|
||||
creds.project = Some(payload.project.as_str().into());
|
||||
@@ -174,8 +190,18 @@ impl BackendType<'_, ClientCredentials<'_>> {
|
||||
|
||||
(node, payload)
|
||||
}
|
||||
Console(endpoint, creds) if creds.use_cleartext_password_flow => {
|
||||
// This is a hack to allow cleartext password in secure connections (wss).
|
||||
let payload = fetch_plaintext_password(client).await?;
|
||||
let creds = creds.as_ref();
|
||||
let node = console::Api::new(endpoint, extra, &creds)
|
||||
.wake_compute()
|
||||
.await?;
|
||||
|
||||
(node, payload)
|
||||
}
|
||||
Postgres(endpoint, creds) if creds.project.is_none() => {
|
||||
let payload = fetch_magic_payload.await?;
|
||||
let payload = fetch_magic_payload(client).await?;
|
||||
|
||||
let mut creds = creds.as_ref();
|
||||
creds.project = Some(payload.project.as_str().into());
|
||||
|
||||
@@ -34,6 +34,9 @@ pub struct ClientCredentials<'a> {
|
||||
pub user: &'a str,
|
||||
pub dbname: &'a str,
|
||||
pub project: Option<Cow<'a, str>>,
|
||||
/// If `True`, we'll use the old cleartext password flow. This is used for
|
||||
/// websocket connections, which want to minimize the number of round trips.
|
||||
pub use_cleartext_password_flow: bool,
|
||||
}
|
||||
|
||||
impl ClientCredentials<'_> {
|
||||
@@ -50,6 +53,7 @@ impl<'a> ClientCredentials<'a> {
|
||||
user: self.user,
|
||||
dbname: self.dbname,
|
||||
project: self.project().map(Cow::Borrowed),
|
||||
use_cleartext_password_flow: self.use_cleartext_password_flow,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -59,6 +63,7 @@ impl<'a> ClientCredentials<'a> {
|
||||
params: &'a StartupMessageParams,
|
||||
sni: Option<&str>,
|
||||
common_name: Option<&str>,
|
||||
use_cleartext_password_flow: bool,
|
||||
) -> Result<Self, ClientCredsParseError> {
|
||||
use ClientCredsParseError::*;
|
||||
|
||||
@@ -108,6 +113,7 @@ impl<'a> ClientCredentials<'a> {
|
||||
user = user,
|
||||
dbname = dbname,
|
||||
project = project.as_deref(),
|
||||
use_cleartext_password_flow = use_cleartext_password_flow,
|
||||
"credentials"
|
||||
);
|
||||
|
||||
@@ -115,6 +121,7 @@ impl<'a> ClientCredentials<'a> {
|
||||
user,
|
||||
dbname,
|
||||
project,
|
||||
use_cleartext_password_flow,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -141,7 +148,7 @@ mod tests {
|
||||
let options = StartupMessageParams::new([("user", "john_doe")]);
|
||||
|
||||
// TODO: check that `creds.dbname` is None.
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None, false)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
|
||||
Ok(())
|
||||
@@ -151,7 +158,7 @@ mod tests {
|
||||
fn parse_missing_project() -> anyhow::Result<()> {
|
||||
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
|
||||
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None, false)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project, None);
|
||||
@@ -166,7 +173,7 @@ mod tests {
|
||||
let sni = Some("foo.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("foo"));
|
||||
@@ -182,7 +189,7 @@ mod tests {
|
||||
("options", "-ckey=1 project=bar -c geqo=off"),
|
||||
]);
|
||||
|
||||
let creds = ClientCredentials::parse(&options, None, None)?;
|
||||
let creds = ClientCredentials::parse(&options, None, None, false)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("bar"));
|
||||
@@ -201,7 +208,7 @@ mod tests {
|
||||
let sni = Some("baz.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name)?;
|
||||
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
|
||||
assert_eq!(creds.user, "john_doe");
|
||||
assert_eq!(creds.dbname, "world");
|
||||
assert_eq!(creds.project.as_deref(), Some("baz"));
|
||||
@@ -220,7 +227,8 @@ mod tests {
|
||||
let sni = Some("second.localhost");
|
||||
let common_name = Some("localhost");
|
||||
|
||||
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
|
||||
let err =
|
||||
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
|
||||
match err {
|
||||
InconsistentProjectNames { domain, option } => {
|
||||
assert_eq!(option, "first");
|
||||
@@ -237,7 +245,8 @@ mod tests {
|
||||
let sni = Some("project.localhost");
|
||||
let common_name = Some("example.com");
|
||||
|
||||
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
|
||||
let err =
|
||||
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
|
||||
match err {
|
||||
InconsistentSni { sni, cn } => {
|
||||
assert_eq!(sni, "project.localhost");
|
||||
|
||||
@@ -37,6 +37,17 @@ impl AuthMethod for PasswordHack {
|
||||
}
|
||||
}
|
||||
|
||||
/// Use clear-text password auth called `password` in docs
|
||||
/// <https://www.postgresql.org/docs/current/auth-password.html>
|
||||
pub struct CleartextPassword;
|
||||
|
||||
impl AuthMethod for CleartextPassword {
|
||||
#[inline(always)]
|
||||
fn first_message(&self) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub struct AuthFlow<'a, Stream, State> {
|
||||
@@ -86,6 +97,18 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> super::Result<Vec<u8>> {
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
|
||||
|
||||
Ok(password.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod server;
|
||||
pub mod websocket;
|
||||
|
||||
use crate::url::ApiUrl;
|
||||
|
||||
|
||||
263
proxy/src/http/websocket.rs
Normal file
263
proxy/src/http/websocket.rs
Normal file
@@ -0,0 +1,263 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{Sink, Stream, StreamExt};
|
||||
use hyper::server::accept::{self};
|
||||
use hyper::server::conn::AddrIncoming;
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper::{Body, Request, Response, StatusCode};
|
||||
use hyper_tungstenite::{tungstenite, WebSocketStream};
|
||||
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket};
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use std::convert::Infallible;
|
||||
use std::future::ready;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tls_listener::TlsListener;
|
||||
|
||||
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
use utils::http::{error::ApiError, json::json_response};
|
||||
|
||||
use crate::cancellation::CancelMap;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::proxy::handle_ws_client;
|
||||
|
||||
pin_project! {
|
||||
/// This is a wrapper around a WebSocketStream that implements AsyncRead and AsyncWrite.
|
||||
pub struct WebSocketRW {
|
||||
#[pin]
|
||||
stream: WebSocketStream<Upgraded>,
|
||||
chunk: Option<bytes::Bytes>,
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: explain why this is safe or try to remove `unsafe impl`.
|
||||
unsafe impl Sync for WebSocketRW {}
|
||||
|
||||
impl WebSocketRW {
|
||||
pub fn new(stream: WebSocketStream<Upgraded>) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
chunk: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn has_chunk(&self) -> bool {
|
||||
if let Some(ref chunk) = self.chunk {
|
||||
chunk.remaining() > 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn ws_err_into(e: tungstenite::Error) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, e.to_string())
|
||||
}
|
||||
|
||||
impl AsyncWrite for WebSocketRW {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
let mut this = self.project();
|
||||
match this.stream.as_mut().poll_ready(cx) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
if let Err(e) = this
|
||||
.stream
|
||||
.as_mut()
|
||||
.start_send(Message::Binary(buf.to_vec()))
|
||||
{
|
||||
Poll::Ready(Err(ws_err_into(e)))
|
||||
} else {
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(ws_err_into(e))),
|
||||
Poll::Pending => {
|
||||
cx.waker().wake_by_ref();
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().stream.poll_flush(cx).map_err(ws_err_into)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.project().stream.poll_close(cx).map_err(ws_err_into)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for WebSocketRW {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if buf.remaining() == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
let inner_buf = match self.as_mut().poll_fill_buf(cx) {
|
||||
Poll::Ready(Ok(buf)) => buf,
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
};
|
||||
let len = std::cmp::min(inner_buf.len(), buf.remaining());
|
||||
buf.put_slice(&inner_buf[..len]);
|
||||
|
||||
self.consume(len);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncBufRead for WebSocketRW {
|
||||
fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
|
||||
loop {
|
||||
if self.as_mut().has_chunk() {
|
||||
let buf = self.project().chunk.as_ref().unwrap().chunk();
|
||||
return Poll::Ready(Ok(buf));
|
||||
} else {
|
||||
match self.as_mut().project().stream.poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(message))) => match message {
|
||||
Message::Text(_) => {}
|
||||
Message::Binary(chunk) => {
|
||||
*self.as_mut().project().chunk = Some(Bytes::from(chunk));
|
||||
}
|
||||
Message::Ping(_) => {
|
||||
// No need to send a reply: tungstenite takes care of this for you.
|
||||
}
|
||||
Message::Pong(_) => {}
|
||||
Message::Close(_) => {
|
||||
// No need to send a reply: tungstenite takes care of this for you.
|
||||
return Poll::Ready(Ok(&[]));
|
||||
}
|
||||
Message::Frame(_) => {
|
||||
unreachable!();
|
||||
}
|
||||
},
|
||||
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(ws_err_into(err))),
|
||||
Poll::Ready(None) => return Poll::Ready(Ok(&[])),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn consume(self: Pin<&mut Self>, amt: usize) {
|
||||
if amt > 0 {
|
||||
self.project()
|
||||
.chunk
|
||||
.as_mut()
|
||||
.expect("No chunk present")
|
||||
.advance(amt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn serve_websocket(
|
||||
websocket: HyperWebsocket,
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
hostname: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
handle_ws_client(
|
||||
config,
|
||||
cancel_map,
|
||||
session_id,
|
||||
WebSocketRW::new(websocket),
|
||||
hostname,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ws_handler(
|
||||
mut request: Request<Body>,
|
||||
config: &'static ProxyConfig,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
session_id: uuid::Uuid,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.split(':').next())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// Check if the request is a websocket upgrade request.
|
||||
if hyper_tungstenite::is_upgrade_request(&request) {
|
||||
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = serve_websocket(websocket, config, &cancel_map, session_id, host).await
|
||||
{
|
||||
error!("error in websocket connection: {:?}", e);
|
||||
}
|
||||
});
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
} else {
|
||||
json_response(StatusCode::OK, "Connect with a websocket client")
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn task_main(
|
||||
ws_listener: TcpListener,
|
||||
config: &'static ProxyConfig,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("websocket server has shut down");
|
||||
}
|
||||
|
||||
let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
|
||||
let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
|
||||
Some(config) => config.into(),
|
||||
None => {
|
||||
warn!("TLS config is missing, WebSocket Secure server will not be started");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let addr_incoming = AddrIncoming::from_listener(ws_listener)?;
|
||||
|
||||
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
|
||||
if let Err(err) = conn {
|
||||
error!("failed to accept TLS connection for websockets: {:?}", err);
|
||||
ready(false)
|
||||
} else {
|
||||
ready(true)
|
||||
}
|
||||
});
|
||||
|
||||
let make_svc = hyper::service::make_service_fn(|_stream| async move {
|
||||
Ok::<_, Infallible>(hyper::service::service_fn(
|
||||
move |req: Request<Body>| async move {
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
ws_handler(req, config, cancel_map, session_id)
|
||||
.instrument(info_span!(
|
||||
"ws-client",
|
||||
session = format_args!("{session_id}")
|
||||
))
|
||||
.await
|
||||
},
|
||||
))
|
||||
});
|
||||
|
||||
hyper::Server::builder(accept::from_stream(tls_listener))
|
||||
.serve(make_svc)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -110,12 +110,23 @@ async fn main() -> anyhow::Result<()> {
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
|
||||
let tasks = [
|
||||
let mut tasks = vec![
|
||||
tokio::spawn(http::server::task_main(http_listener)),
|
||||
tokio::spawn(proxy::task_main(config, proxy_listener)),
|
||||
tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener)),
|
||||
]
|
||||
.map(flatten_err);
|
||||
];
|
||||
|
||||
if let Some(wss_address) = arg_matches.get_one::<String>("wss") {
|
||||
let wss_address: SocketAddr = wss_address.parse()?;
|
||||
info!("Starting wss on {}", wss_address);
|
||||
let wss_listener = TcpListener::bind(wss_address).await?;
|
||||
tasks.push(tokio::spawn(http::websocket::task_main(
|
||||
wss_listener,
|
||||
config,
|
||||
)));
|
||||
}
|
||||
|
||||
let tasks = tasks.into_iter().map(flatten_err);
|
||||
|
||||
set_build_info_metric(GIT_VERSION);
|
||||
// This will block until all tasks have completed.
|
||||
@@ -155,6 +166,11 @@ fn cli() -> clap::Command {
|
||||
.help("listen for incoming http connections (metrics, etc) on ip:port")
|
||||
.default_value("127.0.0.1:7001"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("wss")
|
||||
.long("wss")
|
||||
.help("listen for incoming wss connections on ip:port"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("uri")
|
||||
.short('u')
|
||||
|
||||
@@ -9,7 +9,10 @@ use std::{
|
||||
thread,
|
||||
};
|
||||
use tracing::{error, info, info_span};
|
||||
use utils::postgres_backend::{self, AuthType, PostgresBackend};
|
||||
use utils::{
|
||||
postgres_backend::{self, AuthType, PostgresBackend},
|
||||
postgres_backend_async::QueryError,
|
||||
};
|
||||
|
||||
/// Console management API listener thread.
|
||||
/// It spawns console response handlers needed for the link auth.
|
||||
@@ -47,7 +50,7 @@ pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_connection(socket: TcpStream) -> anyhow::Result<()> {
|
||||
fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
|
||||
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?;
|
||||
pgbackend.run(&mut MgmtHandler)
|
||||
}
|
||||
@@ -58,7 +61,7 @@ pub type ComputeReady = Result<DatabaseInfo, String>;
|
||||
// TODO: replace with an http-based protocol.
|
||||
struct MgmtHandler;
|
||||
impl postgres_backend::Handler for MgmtHandler {
|
||||
fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<()> {
|
||||
fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
|
||||
try_process_query(pgb, query).map_err(|e| {
|
||||
error!("failed to process response: {e:?}");
|
||||
e
|
||||
@@ -66,8 +69,8 @@ impl postgres_backend::Handler for MgmtHandler {
|
||||
}
|
||||
}
|
||||
|
||||
fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<()> {
|
||||
let resp: KickSession = serde_json::from_str(query)?;
|
||||
fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
|
||||
let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
|
||||
|
||||
let span = info_span!("event", session_id = resp.session_id);
|
||||
let _enter = span.enter();
|
||||
@@ -81,7 +84,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<(
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to deliver response to per-client task");
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&e.to_string()))?;
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&e.to_string(), None))?;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -82,6 +82,47 @@ pub async fn task_main(
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_ws_client(
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
session_id: uuid::Uuid,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin + Send,
|
||||
hostname: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
// The `closed` counter will increase when this future is destroyed.
|
||||
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
|
||||
scopeguard::defer! {
|
||||
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
|
||||
}
|
||||
|
||||
let tls = config.tls_config.as_ref();
|
||||
let hostname = hostname.as_deref();
|
||||
|
||||
// TLS is None here, because the connection is already encrypted.
|
||||
let do_handshake = handshake(stream, None, cancel_map).instrument(info_span!("handshake"));
|
||||
let (mut stream, params) = match do_handshake.await? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
};
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let creds = {
|
||||
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, hostname, common_name, true))
|
||||
.transpose();
|
||||
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
};
|
||||
|
||||
let client = Client::new(stream, creds, ¶ms, session_id);
|
||||
cancel_map
|
||||
.with_session(|session| client.connect_to_db(session))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn handle_client(
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
@@ -108,7 +149,7 @@ async fn handle_client(
|
||||
let result = config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name))
|
||||
.map(|_| auth::ClientCredentials::parse(¶ms, sni, common_name, false))
|
||||
.transpose();
|
||||
|
||||
async { result }.or_else(|e| stream.throw_error(e)).await?
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::error::UserFacingError;
|
||||
use anyhow::bail;
|
||||
use bytes::BytesMut;
|
||||
use pin_project_lite::pin_project;
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
use rustls::ServerConfig;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
@@ -47,18 +47,13 @@ fn err_connection() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
|
||||
}
|
||||
|
||||
// TODO: change error type of `FeMessage::read_fut`
|
||||
fn from_anyhow(e: anyhow::Error) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, e.to_string())
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
|
||||
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
|
||||
// TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket`
|
||||
let msg = FeStartupPacket::read_fut(&mut self.stream)
|
||||
.await
|
||||
.map_err(from_anyhow)?
|
||||
.map_err(ProtocolError::into_io_error)?
|
||||
.ok_or_else(err_connection)?;
|
||||
|
||||
match msg {
|
||||
@@ -80,7 +75,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
async fn read_message(&mut self) -> io::Result<FeMessage> {
|
||||
FeMessage::read_fut(&mut self.stream)
|
||||
.await
|
||||
.map_err(from_anyhow)?
|
||||
.map_err(ProtocolError::into_io_error)?
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
}
|
||||
@@ -112,7 +107,8 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
|
||||
pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
|
||||
tracing::info!("forwarding error to user: {error}");
|
||||
self.write_message(&BeMessage::ErrorResponse(error)).await?;
|
||||
self.write_message(&BeMessage::ErrorResponse(error, None))
|
||||
.await?;
|
||||
bail!(error)
|
||||
}
|
||||
|
||||
@@ -124,7 +120,8 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
{
|
||||
let msg = error.to_string_client();
|
||||
tracing::info!("forwarding error to user: {msg}");
|
||||
self.write_message(&BeMessage::ErrorResponse(&msg)).await?;
|
||||
self.write_message(&BeMessage::ErrorResponse(&msg, None))
|
||||
.await?;
|
||||
bail!(error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# version, we can consider updating.
|
||||
# See https://tracker.debian.org/pkg/rustc for more details on Debian rustc package,
|
||||
# we use "unstable" version number as the highest version used in the project by default.
|
||||
channel = "1.62.1"
|
||||
channel = "1.66.1"
|
||||
profile = "default"
|
||||
# The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy.
|
||||
# https://rust-lang.github.io/rustup/concepts/profiles.html
|
||||
|
||||
@@ -14,12 +14,14 @@ clap = { version = "4.0", features = ["derive"] }
|
||||
const_format = "0.2.21"
|
||||
crc32c = "0.6.0"
|
||||
fs2 = "0.4.3"
|
||||
futures = "0.3"
|
||||
git-version = "0.3.5"
|
||||
hex = "0.4.3"
|
||||
humantime = "2.1.0"
|
||||
hyper = "0.14"
|
||||
nix = "0.25"
|
||||
once_cell = "1.13.0"
|
||||
pin-project-lite = "0.2"
|
||||
parking_lot = "0.12.1"
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||
|
||||
@@ -228,24 +228,20 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> {
|
||||
|
||||
let conf_cloned = conf.clone();
|
||||
let safekeeper_thread = thread::Builder::new()
|
||||
.name("safekeeper thread".into())
|
||||
.spawn(|| {
|
||||
if let Err(e) = wal_service::thread_main(conf_cloned, pg_listener) {
|
||||
info!("safekeeper thread terminated: {e}");
|
||||
}
|
||||
})
|
||||
.name("WAL service thread".into())
|
||||
.spawn(|| wal_service::thread_main(conf_cloned, pg_listener))
|
||||
.unwrap();
|
||||
|
||||
threads.push(safekeeper_thread);
|
||||
|
||||
let conf_ = conf.clone();
|
||||
threads.push(
|
||||
thread::Builder::new()
|
||||
.name("broker thread".into())
|
||||
.spawn(|| {
|
||||
broker::thread_main(conf_);
|
||||
})?,
|
||||
);
|
||||
// threads.push(
|
||||
// thread::Builder::new()
|
||||
// .name("broker thread".into())
|
||||
// .spawn(|| {
|
||||
// broker::thread_main(conf_);
|
||||
// })?,
|
||||
// );
|
||||
|
||||
let conf_ = conf.clone();
|
||||
threads.push(
|
||||
|
||||
@@ -1,26 +1,23 @@
|
||||
//! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres
|
||||
//! protocol commands.
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use std::str;
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use crate::auth::check_permission;
|
||||
use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage};
|
||||
use crate::receive_wal::ReceiveWalConn;
|
||||
|
||||
use crate::send_wal::ReplicationConn;
|
||||
|
||||
use crate::{GlobalTimelines, SafeKeeperConf};
|
||||
use anyhow::{bail, ensure, Context, Result};
|
||||
|
||||
use postgres_ffi::PG_TLI;
|
||||
use regex::Regex;
|
||||
|
||||
use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
|
||||
use std::str;
|
||||
use tracing::info;
|
||||
use regex::Regex;
|
||||
use utils::auth::{Claims, Scope};
|
||||
use utils::postgres_backend_async::QueryError;
|
||||
use utils::{
|
||||
id::{TenantId, TenantTimelineId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::{self, PostgresBackend},
|
||||
postgres_backend_async::{self, PostgresBackend},
|
||||
};
|
||||
|
||||
/// Safekeeper handler of postgres commands
|
||||
@@ -40,9 +37,11 @@ enum SafekeeperPostgresCommand {
|
||||
StartReplication { start_lsn: Lsn },
|
||||
IdentifySystem,
|
||||
JSONCtrl { cmd: AppendLogicalMessage },
|
||||
Show { guc: String },
|
||||
}
|
||||
|
||||
fn parse_cmd(cmd: &str) -> Result<SafekeeperPostgresCommand> {
|
||||
fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
|
||||
let cmd_lowercase = cmd.to_ascii_lowercase();
|
||||
if cmd.starts_with("START_WAL_PUSH") {
|
||||
Ok(SafekeeperPostgresCommand::StartWalPush)
|
||||
} else if cmd.starts_with("START_REPLICATION") {
|
||||
@@ -52,7 +51,7 @@ fn parse_cmd(cmd: &str) -> Result<SafekeeperPostgresCommand> {
|
||||
let start_lsn = caps
|
||||
.next()
|
||||
.map(|cap| cap[1].parse::<Lsn>())
|
||||
.context("failed to parse start LSN from START_REPLICATION command")??;
|
||||
.context("parse start LSN from START_REPLICATION command")??;
|
||||
Ok(SafekeeperPostgresCommand::StartReplication { start_lsn })
|
||||
} else if cmd.starts_with("IDENTIFY_SYSTEM") {
|
||||
Ok(SafekeeperPostgresCommand::IdentifySystem)
|
||||
@@ -61,14 +60,27 @@ fn parse_cmd(cmd: &str) -> Result<SafekeeperPostgresCommand> {
|
||||
Ok(SafekeeperPostgresCommand::JSONCtrl {
|
||||
cmd: serde_json::from_str(cmd)?,
|
||||
})
|
||||
} else if cmd_lowercase.starts_with("show") {
|
||||
let re = Regex::new(r"show ((?:[[:alpha:]]|_)+)").unwrap();
|
||||
let mut caps = re.captures_iter(&cmd_lowercase);
|
||||
let guc = caps
|
||||
.next()
|
||||
.map(|cap| cap[1].parse::<String>())
|
||||
.context("parse guc in SHOW command")??;
|
||||
Ok(SafekeeperPostgresCommand::Show { guc })
|
||||
} else {
|
||||
bail!("unsupported command {}", cmd);
|
||||
anyhow::bail!("unsupported command {cmd}");
|
||||
}
|
||||
}
|
||||
|
||||
impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
#[async_trait::async_trait]
|
||||
impl postgres_backend_async::Handler for SafekeeperPostgresHandler {
|
||||
// tenant_id and timeline_id are passed in connection string params
|
||||
fn startup(&mut self, _pgb: &mut PostgresBackend, sm: &FeStartupPacket) -> Result<()> {
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
sm: &FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
if let FeStartupPacket::StartupMessage { params, .. } = sm {
|
||||
if let Some(options) = params.options_raw() {
|
||||
for opt in options {
|
||||
@@ -77,10 +89,14 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
// https://github.com/neondatabase/neon/pull/2433#discussion_r970005064
|
||||
match opt.split_once('=') {
|
||||
Some(("ztenantid", value)) | Some(("tenant_id", value)) => {
|
||||
self.tenant_id = Some(value.parse()?);
|
||||
self.tenant_id = Some(value.parse().with_context(|| {
|
||||
format!("Failed to parse {value} as tenant id")
|
||||
})?);
|
||||
}
|
||||
Some(("ztimelineid", value)) | Some(("timeline_id", value)) => {
|
||||
self.timeline_id = Some(value.parse()?);
|
||||
self.timeline_id = Some(value.parse().with_context(|| {
|
||||
format!("Failed to parse {value} as timeline id")
|
||||
})?);
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
@@ -93,7 +109,9 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
bail!("Safekeeper received unexpected initial message: {:?}", sm);
|
||||
Err(QueryError::Other(anyhow::anyhow!(
|
||||
"Safekeeper received unexpected initial message: {sm:?}"
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,7 +119,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
jwt_response: &[u8],
|
||||
) -> anyhow::Result<()> {
|
||||
) -> Result<(), QueryError> {
|
||||
// this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
|
||||
// which requires auth to be present
|
||||
let data = self
|
||||
@@ -109,13 +127,12 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
.auth
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.decode(str::from_utf8(jwt_response)?)?;
|
||||
.decode(str::from_utf8(jwt_response).context("jwt response is not UTF-8")?)?;
|
||||
|
||||
if matches!(data.claims.scope, Scope::Tenant) {
|
||||
ensure!(
|
||||
data.claims.tenant_id.is_some(),
|
||||
if matches!(data.claims.scope, Scope::Tenant) && data.claims.tenant_id.is_none() {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"jwt token scope is Tenant, but tenant id is missing"
|
||||
)
|
||||
)));
|
||||
}
|
||||
|
||||
info!(
|
||||
@@ -127,15 +144,21 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn process_query(&mut self, pgb: &mut PostgresBackend, query_string: &str) -> Result<()> {
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
if query_string
|
||||
.to_ascii_lowercase()
|
||||
.starts_with("set datestyle to ")
|
||||
{
|
||||
// important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect
|
||||
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
pgb.write_message_flush(&BeMessage::CommandComplete(b"SELECT 1"))
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let cmd = parse_cmd(query_string)?;
|
||||
|
||||
info!(
|
||||
@@ -147,20 +170,35 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
let timeline_id = self.timeline_id.context("timelineid is required")?;
|
||||
self.check_permission(Some(tenant_id))?;
|
||||
self.ttid = TenantTimelineId::new(tenant_id, timeline_id);
|
||||
let span_ttid = self.ttid; // satisfy borrow checker
|
||||
|
||||
match cmd {
|
||||
SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self),
|
||||
let res = match cmd {
|
||||
// SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self),
|
||||
SafekeeperPostgresCommand::StartWalPush => Ok(()),
|
||||
SafekeeperPostgresCommand::StartReplication { start_lsn } => {
|
||||
ReplicationConn::new(pgb).run(self, pgb, start_lsn)
|
||||
self.handle_start_replication(pgb, start_lsn)
|
||||
.instrument(info_span!("WAL sender", ttid = %span_ttid))
|
||||
.await
|
||||
}
|
||||
SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb),
|
||||
SafekeeperPostgresCommand::JSONCtrl { ref cmd } => handle_json_ctrl(self, pgb, cmd),
|
||||
}
|
||||
.context(format!(
|
||||
"Failed to process query for timeline {timeline_id}"
|
||||
))?;
|
||||
SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await,
|
||||
SafekeeperPostgresCommand::Show { guc } => self.handle_show(guc, pgb).await,
|
||||
SafekeeperPostgresCommand::JSONCtrl { ref cmd } => {
|
||||
handle_json_ctrl(self, pgb, cmd).await
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
match res {
|
||||
Ok(()) => Ok(()),
|
||||
Err(QueryError::Disconnected(connection_error)) => {
|
||||
info!("Timeline {tenant_id}/{timeline_id} query failed with connection error: {connection_error}");
|
||||
Err(QueryError::Disconnected(connection_error))
|
||||
}
|
||||
Err(QueryError::Other(e)) => Err(QueryError::Other(e.context(format!(
|
||||
"Failed to process query for timeline {}",
|
||||
self.ttid
|
||||
)))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -178,7 +216,7 @@ impl SafekeeperPostgresHandler {
|
||||
|
||||
// when accessing management api supply None as an argument
|
||||
// when using to authorize tenant pass corresponding tenant id
|
||||
fn check_permission(&self, tenant_id: Option<TenantId>) -> Result<()> {
|
||||
fn check_permission(&self, tenant_id: Option<TenantId>) -> anyhow::Result<()> {
|
||||
if self.conf.auth.is_none() {
|
||||
// auth is set to Trust, nothing to check so just return ok
|
||||
return Ok(());
|
||||
@@ -196,7 +234,10 @@ impl SafekeeperPostgresHandler {
|
||||
///
|
||||
/// Handle IDENTIFY_SYSTEM replication command
|
||||
///
|
||||
fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<()> {
|
||||
async fn handle_identify_system(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
) -> Result<(), QueryError> {
|
||||
let tli = GlobalTimelines::get(self.ttid)?;
|
||||
|
||||
let lsn = if self.is_walproposer_recovery() {
|
||||
@@ -214,7 +255,7 @@ impl SafekeeperPostgresHandler {
|
||||
let tli_bytes = tli.as_bytes();
|
||||
let sysid_bytes = sysid.as_bytes();
|
||||
|
||||
pgb.write_message_noflush(&BeMessage::RowDescription(&[
|
||||
pgb.write_message(&BeMessage::RowDescription(&[
|
||||
RowDescriptor {
|
||||
name: b"systemid",
|
||||
typoid: TEXT_OID,
|
||||
@@ -240,13 +281,48 @@ impl SafekeeperPostgresHandler {
|
||||
..Default::default()
|
||||
},
|
||||
]))?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[
|
||||
.write_message(&BeMessage::DataRow(&[
|
||||
Some(sysid_bytes),
|
||||
Some(tli_bytes),
|
||||
Some(lsn_bytes),
|
||||
None,
|
||||
]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
|
||||
.write_message_flush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_show(
|
||||
&mut self,
|
||||
guc: String,
|
||||
pgb: &mut PostgresBackend,
|
||||
) -> Result<(), QueryError> {
|
||||
match guc.as_str() {
|
||||
// pg_receivewal wants it
|
||||
"data_directory_mode" => {
|
||||
pgb.write_message(&BeMessage::RowDescription(&[RowDescriptor::int8_col(
|
||||
b"data_directory_mode",
|
||||
)]))?
|
||||
// xxx we could return real one, not just 0700
|
||||
.write_message(&BeMessage::DataRow(&[Some(0700.to_string().as_bytes())]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
}
|
||||
// pg_receivewal wants it
|
||||
"wal_segment_size" => {
|
||||
let tli = GlobalTimelines::get(self.ttid)?;
|
||||
let wal_seg_size = tli.get_state().1.server.wal_seg_size;
|
||||
let wal_seg_size_mb = (wal_seg_size / 1024 / 1024).to_string() + "MB";
|
||||
|
||||
pgb.write_message(&BeMessage::RowDescription(&[RowDescriptor::text_col(
|
||||
b"wal_segment_size",
|
||||
)]))?
|
||||
.write_message(&BeMessage::DataRow(&[Some(wal_seg_size_mb.as_bytes())]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
}
|
||||
_ => {
|
||||
return Err(anyhow::anyhow!("SHOW of unknown setting").into());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -8,11 +8,14 @@ use serde::Serialize;
|
||||
use serde::Serializer;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fmt::Display;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use storage_broker::proto::SafekeeperTimelineInfo;
|
||||
use storage_broker::proto::TenantTimelineId as ProtoTenantTimelineId;
|
||||
use tokio::task::JoinError;
|
||||
|
||||
use crate::json_ctrl::append_logical_message;
|
||||
use crate::json_ctrl::AppendLogicalMessage;
|
||||
use crate::safekeeper::ServerInfo;
|
||||
use crate::safekeeper::Term;
|
||||
|
||||
@@ -191,6 +194,50 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
|
||||
// Create fake timeline + insert some valid WAL. Useful to test WAL streaming
|
||||
// from safekeeper in isolation, e.g.
|
||||
// pg_receivewal -v -d "host=localhost port=5454 options='-c tenant_id=deadbeefdeadbeefdeadbeefdeadbeef timeline_id=deadbeefdeadbeefdeadbeefdeadbeef'" -D ~/tmp/tmp/tmp
|
||||
// (hacking pg_receivewal startpos is currently needed though to make pg_receivewal work)
|
||||
async fn create_fake_timeline_handler(_request: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
let ttid = TenantTimelineId {
|
||||
tenant_id: TenantId::from_str("deadbeefdeadbeefdeadbeefdeadbeef")
|
||||
.expect("timeline_id parsing failed"),
|
||||
timeline_id: TimelineId::from_str("deadbeefdeadbeefdeadbeefdeadbeef")
|
||||
.expect("tenant_id parsing failed"),
|
||||
};
|
||||
let pg_version = 150000;
|
||||
let server_info = ServerInfo {
|
||||
pg_version,
|
||||
system_id: 0,
|
||||
wal_seg_size: WAL_SEGMENT_SIZE as u32,
|
||||
};
|
||||
let init_lsn = Lsn(0x1493AC8);
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||
let tli = GlobalTimelines::create(ttid, server_info, init_lsn, init_lsn)?;
|
||||
let mut begin_lsn = init_lsn;
|
||||
for _ in 0..16 {
|
||||
let append = AppendLogicalMessage {
|
||||
lm_prefix: "db".to_owned(),
|
||||
lm_message: "hahabubu".to_owned(),
|
||||
set_commit_lsn: true,
|
||||
send_proposer_elected: false, // actually ignored here
|
||||
term: 0,
|
||||
epoch_start_lsn: init_lsn,
|
||||
begin_lsn,
|
||||
truncate_lsn: init_lsn,
|
||||
pg_version,
|
||||
};
|
||||
let inserted = append_logical_message(&tli, &append)?;
|
||||
begin_lsn = inserted.end_lsn;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
|
||||
/// Deactivates the timeline and removes its data directory.
|
||||
async fn timeline_delete_force_handler(
|
||||
mut request: Request<Body>,
|
||||
@@ -302,6 +349,7 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError>
|
||||
.get("/v1/status", status_handler)
|
||||
// Will be used in the future instead of implicit timeline creation
|
||||
.post("/v1/tenant/timeline", timeline_create_handler)
|
||||
.post("/v1/fake_timeline", create_fake_timeline_handler)
|
||||
.get(
|
||||
"/v1/tenant/:tenant_id/timeline/:timeline_id",
|
||||
timeline_status_handler,
|
||||
|
||||
@@ -8,11 +8,12 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::*;
|
||||
use utils::id::TenantTimelineId;
|
||||
use utils::postgres_backend_async::QueryError;
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use crate::safekeeper::{AcceptorProposerMessage, AppendResponse, ServerInfo};
|
||||
@@ -25,29 +26,29 @@ use crate::GlobalTimelines;
|
||||
use postgres_ffi::encode_logical_message;
|
||||
use postgres_ffi::WAL_SEGMENT_SIZE;
|
||||
use pq_proto::{BeMessage, RowDescriptor, TEXT_OID};
|
||||
use utils::{lsn::Lsn, postgres_backend::PostgresBackend};
|
||||
use utils::{lsn::Lsn, postgres_backend_async::PostgresBackend};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct AppendLogicalMessage {
|
||||
// prefix and message to build LogicalMessage
|
||||
lm_prefix: String,
|
||||
lm_message: String,
|
||||
pub lm_prefix: String,
|
||||
pub lm_message: String,
|
||||
|
||||
// if true, commit_lsn will match flush_lsn after append
|
||||
set_commit_lsn: bool,
|
||||
pub set_commit_lsn: bool,
|
||||
|
||||
// if true, ProposerElected will be sent before append
|
||||
send_proposer_elected: bool,
|
||||
pub send_proposer_elected: bool,
|
||||
|
||||
// fields from AppendRequestHeader
|
||||
term: Term,
|
||||
epoch_start_lsn: Lsn,
|
||||
begin_lsn: Lsn,
|
||||
truncate_lsn: Lsn,
|
||||
pg_version: u32,
|
||||
pub term: Term,
|
||||
pub epoch_start_lsn: Lsn,
|
||||
pub begin_lsn: Lsn,
|
||||
pub truncate_lsn: Lsn,
|
||||
pub pg_version: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct AppendResult {
|
||||
// safekeeper state after append
|
||||
state: SafeKeeperState,
|
||||
@@ -58,12 +59,12 @@ struct AppendResult {
|
||||
/// Handles command to craft logical message WAL record with given
|
||||
/// content, and then append it with specified term and lsn. This
|
||||
/// function is used to test safekeepers in different scenarios.
|
||||
pub fn handle_json_ctrl(
|
||||
pub async fn handle_json_ctrl(
|
||||
spg: &SafekeeperPostgresHandler,
|
||||
pgb: &mut PostgresBackend,
|
||||
append_request: &AppendLogicalMessage,
|
||||
) -> Result<()> {
|
||||
info!("JSON_CTRL request: {:?}", append_request);
|
||||
) -> Result<(), QueryError> {
|
||||
info!("JSON_CTRL request: {append_request:?}");
|
||||
|
||||
// need to init safekeeper state before AppendRequest
|
||||
let tli = prepare_safekeeper(spg.ttid, append_request.pg_version)?;
|
||||
@@ -78,22 +79,24 @@ pub fn handle_json_ctrl(
|
||||
state: tli.get_state().1,
|
||||
inserted_wal,
|
||||
};
|
||||
let response_data = serde_json::to_vec(&response)?;
|
||||
let response_data = serde_json::to_vec(&response)
|
||||
.with_context(|| format!("Response {response:?} is not a json array"))?;
|
||||
|
||||
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor {
|
||||
pgb.write_message(&BeMessage::RowDescription(&[RowDescriptor {
|
||||
name: b"json",
|
||||
typoid: TEXT_OID,
|
||||
typlen: -1,
|
||||
..Default::default()
|
||||
}]))?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(&response_data)]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"JSON_CTRL"))?;
|
||||
.write_message(&BeMessage::DataRow(&[Some(&response_data)]))?
|
||||
.write_message_flush(&BeMessage::CommandComplete(b"JSON_CTRL"))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Prepare safekeeper to process append requests without crashes,
|
||||
/// by sending ProposerGreeting with default server.wal_seg_size.
|
||||
fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> Result<Arc<Timeline>> {
|
||||
fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> anyhow::Result<Arc<Timeline>> {
|
||||
GlobalTimelines::create(
|
||||
ttid,
|
||||
ServerInfo {
|
||||
@@ -106,7 +109,7 @@ fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> Result<Arc<Tim
|
||||
)
|
||||
}
|
||||
|
||||
fn send_proposer_elected(tli: &Arc<Timeline>, term: Term, lsn: Lsn) -> Result<()> {
|
||||
fn send_proposer_elected(tli: &Arc<Timeline>, term: Term, lsn: Lsn) -> anyhow::Result<()> {
|
||||
// add new term to existing history
|
||||
let history = tli.get_state().1.acceptor_state.term_history;
|
||||
let history = history.up_to(lsn.checked_sub(1u64).unwrap());
|
||||
@@ -125,16 +128,19 @@ fn send_proposer_elected(tli: &Arc<Timeline>, term: Term, lsn: Lsn) -> Result<()
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct InsertedWAL {
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct InsertedWAL {
|
||||
begin_lsn: Lsn,
|
||||
end_lsn: Lsn,
|
||||
pub end_lsn: Lsn,
|
||||
append_response: AppendResponse,
|
||||
}
|
||||
|
||||
/// Extend local WAL with new LogicalMessage record. To do that,
|
||||
/// create AppendRequest with new WAL and pass it to safekeeper.
|
||||
fn append_logical_message(tli: &Arc<Timeline>, msg: &AppendLogicalMessage) -> Result<InsertedWAL> {
|
||||
pub fn append_logical_message(
|
||||
tli: &Arc<Timeline>,
|
||||
msg: &AppendLogicalMessage,
|
||||
) -> anyhow::Result<InsertedWAL> {
|
||||
let wal_data = encode_logical_message(&msg.lm_prefix, &msg.lm_message);
|
||||
let sk_state = tli.get_state().1;
|
||||
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
//! Gets messages from the network, passes them down to consensus module and
|
||||
//! sends replies back.
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use anyhow::anyhow;
|
||||
use anyhow::Context;
|
||||
|
||||
use bytes::BytesMut;
|
||||
use tracing::*;
|
||||
use utils::lsn::Lsn;
|
||||
use utils::postgres_backend_async::QueryError;
|
||||
|
||||
use crate::safekeeper::ServerInfo;
|
||||
use crate::timeline::Timeline;
|
||||
@@ -24,7 +26,7 @@ use crate::safekeeper::ProposerAcceptorMessage;
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use pq_proto::{BeMessage, FeMessage};
|
||||
use utils::{postgres_backend::PostgresBackend, sock_split::ReadStream};
|
||||
use utils::{postgres_backend_async::PostgresBackend, sock_split::ReadStream};
|
||||
|
||||
pub struct ReceiveWalConn<'pg> {
|
||||
/// Postgres connection
|
||||
@@ -43,7 +45,7 @@ impl<'pg> ReceiveWalConn<'pg> {
|
||||
}
|
||||
|
||||
// Send message to the postgres
|
||||
fn write_msg(&mut self, msg: &AcceptorProposerMessage) -> Result<()> {
|
||||
fn write_msg(&mut self, msg: &AcceptorProposerMessage) -> anyhow::Result<()> {
|
||||
let mut buf = BytesMut::with_capacity(128);
|
||||
msg.serialize(&mut buf)?;
|
||||
self.pg_backend.write_message(&BeMessage::CopyData(&buf))?;
|
||||
@@ -51,129 +53,143 @@ impl<'pg> ReceiveWalConn<'pg> {
|
||||
}
|
||||
|
||||
/// Receive WAL from wal_proposer
|
||||
pub fn run(&mut self, spg: &mut SafekeeperPostgresHandler) -> Result<()> {
|
||||
pub fn run(&mut self, spg: &mut SafekeeperPostgresHandler) -> Result<(), QueryError> {
|
||||
let _enter = info_span!("WAL acceptor", ttid = %spg.ttid).entered();
|
||||
|
||||
// Notify the libpq client that it's allowed to send `CopyData` messages
|
||||
self.pg_backend
|
||||
.write_message(&BeMessage::CopyBothResponse)?;
|
||||
|
||||
let r = self
|
||||
.pg_backend
|
||||
.take_stream_in()
|
||||
.ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?;
|
||||
let mut poll_reader = ProposerPollStream::new(r)?;
|
||||
Ok(())
|
||||
// let r = self
|
||||
// .pg_backend
|
||||
// .take_stream_in()
|
||||
// .ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?;
|
||||
// let mut poll_reader = ProposerPollStream::new(r)?;
|
||||
|
||||
// Receive information about server
|
||||
let next_msg = poll_reader.recv_msg()?;
|
||||
let tli = match next_msg {
|
||||
ProposerAcceptorMessage::Greeting(ref greeting) => {
|
||||
info!(
|
||||
"start handshake with walproposer {} sysid {} timeline {}",
|
||||
self.peer_addr, greeting.system_id, greeting.tli,
|
||||
);
|
||||
let server_info = ServerInfo {
|
||||
pg_version: greeting.pg_version,
|
||||
system_id: greeting.system_id,
|
||||
wal_seg_size: greeting.wal_seg_size,
|
||||
};
|
||||
GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)?
|
||||
}
|
||||
_ => bail!("unexpected message {:?} instead of greeting", next_msg),
|
||||
};
|
||||
// let next_msg = poll_reader.recv_msg()?;
|
||||
// let tli = match next_msg {
|
||||
// ProposerAcceptorMessage::Greeting(ref greeting) => {
|
||||
// info!(
|
||||
// "start handshake with walproposer {} sysid {} timeline {}",
|
||||
// self.peer_addr, greeting.system_id, greeting.tli,
|
||||
// );
|
||||
// let server_info = ServerInfo {
|
||||
// pg_version: greeting.pg_version,
|
||||
// system_id: greeting.system_id,
|
||||
// wal_seg_size: greeting.wal_seg_size,
|
||||
// };
|
||||
// GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)?
|
||||
// }
|
||||
// _ => {
|
||||
// return Err(QueryError::Other(anyhow::anyhow!(
|
||||
// "unexpected message {next_msg:?} instead of greeting"
|
||||
// )))
|
||||
// }
|
||||
// };
|
||||
|
||||
let mut next_msg = Some(next_msg);
|
||||
// let mut next_msg = None;
|
||||
|
||||
let mut first_time_through = true;
|
||||
let mut _guard: Option<ComputeConnectionGuard> = None;
|
||||
loop {
|
||||
if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) {
|
||||
// poll AppendRequest's without blocking and write WAL to disk without flushing,
|
||||
// while it's readily available
|
||||
while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg {
|
||||
let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
|
||||
// let mut first_time_through = true;
|
||||
// let mut _guard: Option<ComputeConnectionGuard> = None;
|
||||
// loop {
|
||||
// if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) {
|
||||
// // poll AppendRequest's without blocking and write WAL to disk without flushing,
|
||||
// // while it's readily available
|
||||
// while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg {
|
||||
// let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
|
||||
|
||||
let reply = tli.process_msg(&msg)?;
|
||||
if let Some(reply) = reply {
|
||||
self.write_msg(&reply)?;
|
||||
}
|
||||
// let reply = tli.process_msg(&msg)?;
|
||||
// if let Some(reply) = reply {
|
||||
// self.write_msg(&reply)?;
|
||||
// }
|
||||
|
||||
next_msg = poll_reader.poll_msg();
|
||||
}
|
||||
// // next_msg = poll_reader.poll_msg();
|
||||
// next_msg = poll_reader.poll_msg();
|
||||
// }
|
||||
|
||||
// flush all written WAL to the disk
|
||||
let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?;
|
||||
if let Some(reply) = reply {
|
||||
self.write_msg(&reply)?;
|
||||
}
|
||||
} else if let Some(msg) = next_msg.take() {
|
||||
// process other message
|
||||
let reply = tli.process_msg(&msg)?;
|
||||
if let Some(reply) = reply {
|
||||
self.write_msg(&reply)?;
|
||||
}
|
||||
}
|
||||
if first_time_through {
|
||||
// Register the connection and defer unregister. Do that only
|
||||
// after processing first message, as it sets wal_seg_size,
|
||||
// wanted by many.
|
||||
tli.on_compute_connect()?;
|
||||
_guard = Some(ComputeConnectionGuard {
|
||||
timeline: Arc::clone(&tli),
|
||||
});
|
||||
first_time_through = false;
|
||||
}
|
||||
// // flush all written WAL to the disk
|
||||
// let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?;
|
||||
// if let Some(reply) = reply {
|
||||
// self.write_msg(&reply)?;
|
||||
// }
|
||||
// } else if let Some(msg) = next_msg.take() {
|
||||
// // process other message
|
||||
// let reply = tli.process_msg(&msg)?;
|
||||
// if let Some(reply) = reply {
|
||||
// self.write_msg(&reply)?;
|
||||
// }
|
||||
// }
|
||||
// if first_time_through {
|
||||
// // Register the connection and defer unregister. Do that only
|
||||
// // after processing first message, as it sets wal_seg_size,
|
||||
// // wanted by many.
|
||||
// tli.on_compute_connect()?;
|
||||
// _guard = Some(ComputeConnectionGuard {
|
||||
// timeline: Arc::clone(&tli),
|
||||
// });
|
||||
// first_time_through = false;
|
||||
// }
|
||||
|
||||
// blocking wait for the next message
|
||||
if next_msg.is_none() {
|
||||
next_msg = Some(poll_reader.recv_msg()?);
|
||||
}
|
||||
}
|
||||
// // blocking wait for the next message
|
||||
// if next_msg.is_none() {
|
||||
// next_msg = Some(poll_reader.recv_msg()?);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
struct ProposerPollStream {
|
||||
msg_rx: Receiver<ProposerAcceptorMessage>,
|
||||
read_thread: Option<thread::JoinHandle<Result<()>>>,
|
||||
read_thread: Option<thread::JoinHandle<Result<(), QueryError>>>,
|
||||
}
|
||||
|
||||
impl ProposerPollStream {
|
||||
fn new(mut r: ReadStream) -> Result<Self> {
|
||||
let (msg_tx, msg_rx) = channel();
|
||||
// fn new(mut r: ReadStream) -> anyhow::Result<Self> {
|
||||
// let (msg_tx, msg_rx) = channel();
|
||||
|
||||
let read_thread = thread::Builder::new()
|
||||
.name("Read WAL thread".into())
|
||||
.spawn(move || -> Result<()> {
|
||||
loop {
|
||||
let copy_data = match FeMessage::read(&mut r)? {
|
||||
Some(FeMessage::CopyData(bytes)) => bytes,
|
||||
Some(msg) => bail!("expected `CopyData` message, found {:?}", msg),
|
||||
None => bail!("connection closed unexpectedly"),
|
||||
};
|
||||
// let read_thread = thread::Builder::new()
|
||||
// .name("Read WAL thread".into())
|
||||
// .spawn(move || -> Result<(), QueryError> {
|
||||
// loop {
|
||||
// let copy_data = match FeMessage::read(&mut r)? {
|
||||
// Some(FeMessage::CopyData(bytes)) => Ok(bytes),
|
||||
// Some(msg) => Err(QueryError::Other(anyhow::anyhow!(
|
||||
// "expected `CopyData` message, found {msg:?}"
|
||||
// ))),
|
||||
// None => Err(QueryError::from(std::io::Error::new(
|
||||
// std::io::ErrorKind::ConnectionAborted,
|
||||
// "walproposer closed the connection",
|
||||
// ))),
|
||||
// }?;
|
||||
|
||||
let msg = ProposerAcceptorMessage::parse(copy_data)?;
|
||||
msg_tx.send(msg)?;
|
||||
}
|
||||
// msg_tx will be dropped here, this will also close msg_rx
|
||||
})?;
|
||||
// let msg = ProposerAcceptorMessage::parse(copy_data)?;
|
||||
// msg_tx
|
||||
// .send(msg)
|
||||
// .context("Failed to send the proposer message")?;
|
||||
// }
|
||||
// // msg_tx will be dropped here, this will also close msg_rx
|
||||
// })?;
|
||||
|
||||
Ok(Self {
|
||||
msg_rx,
|
||||
read_thread: Some(read_thread),
|
||||
})
|
||||
}
|
||||
// Ok(Self {
|
||||
// msg_rx,
|
||||
// read_thread: Some(read_thread),
|
||||
// })
|
||||
// }
|
||||
|
||||
fn recv_msg(&mut self) -> Result<ProposerAcceptorMessage> {
|
||||
fn recv_msg(&mut self) -> Result<ProposerAcceptorMessage, QueryError> {
|
||||
self.msg_rx.recv().map_err(|_| {
|
||||
// return error from the read thread
|
||||
let res = match self.read_thread.take() {
|
||||
Some(thread) => thread.join(),
|
||||
None => return anyhow!("read thread is gone"),
|
||||
None => return QueryError::Other(anyhow::anyhow!("read thread is gone")),
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(Ok(())) => anyhow!("unexpected result from read thread"),
|
||||
Err(err) => anyhow!("read thread panicked: {:?}", err),
|
||||
Ok(Ok(())) => {
|
||||
QueryError::Other(anyhow::anyhow!("unexpected result from read thread"))
|
||||
}
|
||||
Err(err) => QueryError::Other(anyhow::anyhow!("read thread panicked: {err:?}")),
|
||||
Ok(Err(err)) => err,
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,27 +1,36 @@
|
||||
//! This module implements the streaming side of replication protocol, starting
|
||||
//! with the "START_REPLICATION" message.
|
||||
|
||||
use anyhow::Context as AnyhowContext;
|
||||
use bytes::Bytes;
|
||||
use futures::future::BoxFuture;
|
||||
use pin_project_lite::pin_project;
|
||||
use postgres_ffi::get_current_timestamp;
|
||||
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cell::RefCell;
|
||||
use std::cmp::min;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::task::{ready, Context, Poll};
|
||||
use std::time::Duration;
|
||||
use std::{io, str, thread};
|
||||
use tokio::sync::watch::Receiver;
|
||||
use tokio::time::timeout;
|
||||
use tracing::*;
|
||||
use utils::postgres_backend_async::QueryError;
|
||||
use utils::send_rc::RefCellSend;
|
||||
use utils::send_rc::SendRc;
|
||||
|
||||
use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
|
||||
use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend_async::PostgresBackend};
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use crate::timeline::{ReplicaState, Timeline};
|
||||
use crate::wal_storage::WalReader;
|
||||
use crate::GlobalTimelines;
|
||||
use anyhow::{bail, Context, Result};
|
||||
|
||||
use bytes::Bytes;
|
||||
use postgres_ffi::get_current_timestamp;
|
||||
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::min;
|
||||
use std::net::Shutdown;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{str, thread};
|
||||
|
||||
use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
|
||||
use tokio::sync::watch::Receiver;
|
||||
use tokio::time::timeout;
|
||||
use tracing::*;
|
||||
use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend::PostgresBackend, sock_split::ReadStream};
|
||||
|
||||
// See: https://www.postgresql.org/docs/13/protocol-replication.html
|
||||
const HOT_STANDBY_FEEDBACK_TAG_BYTE: u8 = b'h';
|
||||
@@ -59,13 +68,6 @@ pub struct StandbyReply {
|
||||
pub reply_requested: bool,
|
||||
}
|
||||
|
||||
/// A network connection that's speaking the replication protocol.
|
||||
pub struct ReplicationConn {
|
||||
/// This is an `Option` because we will spawn a background thread that will
|
||||
/// `take` it from us.
|
||||
stream_in: Option<ReadStream>,
|
||||
}
|
||||
|
||||
/// Scope guard to unregister replication connection from timeline
|
||||
struct ReplicationConnGuard {
|
||||
replica: usize, // replica internal ID assigned by timeline
|
||||
@@ -78,230 +80,419 @@ impl Drop for ReplicationConnGuard {
|
||||
}
|
||||
}
|
||||
|
||||
impl ReplicationConn {
|
||||
/// Create a new `ReplicationConn`
|
||||
pub fn new(pgb: &mut PostgresBackend) -> Self {
|
||||
Self {
|
||||
stream_in: pgb.take_stream_in(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle incoming messages from the network.
|
||||
/// This is spawned into the background by `handle_start_replication`.
|
||||
fn background_thread(
|
||||
mut stream_in: ReadStream,
|
||||
replica_guard: Arc<ReplicationConnGuard>,
|
||||
) -> Result<()> {
|
||||
let replica_id = replica_guard.replica;
|
||||
let timeline = &replica_guard.timeline;
|
||||
|
||||
let mut state = ReplicaState::new();
|
||||
// Wait for replica's feedback.
|
||||
while let Some(msg) = FeMessage::read(&mut stream_in)? {
|
||||
match &msg {
|
||||
FeMessage::CopyData(m) => {
|
||||
// There's three possible data messages that the client is supposed to send here:
|
||||
// `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`.
|
||||
|
||||
match m.first().cloned() {
|
||||
Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
|
||||
// Note: deserializing is on m[1..] because we skip the tag byte.
|
||||
state.hs_feedback = HotStandbyFeedback::des(&m[1..])
|
||||
.context("failed to deserialize HotStandbyFeedback")?;
|
||||
timeline.update_replica_state(replica_id, state);
|
||||
}
|
||||
Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => {
|
||||
let _reply = StandbyReply::des(&m[1..])
|
||||
.context("failed to deserialize StandbyReply")?;
|
||||
// This must be a regular postgres replica,
|
||||
// because pageserver doesn't send this type of messages to safekeeper.
|
||||
// Currently this is not implemented, so this message is ignored.
|
||||
|
||||
warn!("unexpected StandbyReply. Read-only postgres replicas are not supported in safekeepers yet.");
|
||||
// timeline.update_replica_state(replica_id, Some(state));
|
||||
}
|
||||
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
|
||||
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
|
||||
let buf = Bytes::copy_from_slice(&m[9..]);
|
||||
let reply = ReplicationFeedback::parse(buf);
|
||||
|
||||
trace!("ReplicationFeedback is {:?}", reply);
|
||||
// Only pageserver sends ReplicationFeedback, so set the flag.
|
||||
// This replica is the source of information to resend to compute.
|
||||
state.pageserver_feedback = Some(reply);
|
||||
|
||||
timeline.update_replica_state(replica_id, state);
|
||||
}
|
||||
_ => warn!("unexpected message {:?}", msg),
|
||||
}
|
||||
}
|
||||
FeMessage::Sync => {}
|
||||
FeMessage::CopyFail => {
|
||||
// Shutdown the connection, because rust-postgres client cannot be dropped
|
||||
// when connection is alive.
|
||||
let _ = stream_in.shutdown(Shutdown::Both);
|
||||
bail!("Copy failed");
|
||||
}
|
||||
_ => {
|
||||
// We only handle `CopyData`, 'Sync', 'CopyFail' messages. Anything else is ignored.
|
||||
info!("unexpected message {:?}", msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
///
|
||||
/// Handle START_REPLICATION replication command
|
||||
///
|
||||
pub fn run(
|
||||
impl SafekeeperPostgresHandler {
|
||||
pub async fn handle_start_replication(
|
||||
&mut self,
|
||||
spg: &mut SafekeeperPostgresHandler,
|
||||
pgb: &mut PostgresBackend,
|
||||
mut start_pos: Lsn,
|
||||
) -> Result<()> {
|
||||
let _enter = info_span!("WAL sender", ttid = %spg.ttid).entered();
|
||||
|
||||
let tli = GlobalTimelines::get(spg.ttid)?;
|
||||
|
||||
// spawn the background thread which receives HotStandbyFeedback messages.
|
||||
let bg_timeline = Arc::clone(&tli);
|
||||
let bg_stream_in = self.stream_in.take().unwrap();
|
||||
let bg_timeline_id = spg.timeline_id.unwrap();
|
||||
start_pos: Lsn,
|
||||
) -> Result<(), QueryError> {
|
||||
let appname = self.appname.clone();
|
||||
let tli = GlobalTimelines::get(self.ttid)?;
|
||||
|
||||
let state = ReplicaState::new();
|
||||
// This replica_id is used below to check if it's time to stop replication.
|
||||
let replica_id = bg_timeline.add_replica(state);
|
||||
let replica_id = tli.add_replica(state);
|
||||
|
||||
// Use a guard object to remove our entry from the timeline, when the background
|
||||
// thread and us have both finished using it.
|
||||
let replica_guard = Arc::new(ReplicationConnGuard {
|
||||
let _guard = Arc::new(ReplicationConnGuard {
|
||||
replica: replica_id,
|
||||
timeline: bg_timeline,
|
||||
timeline: tli.clone(),
|
||||
});
|
||||
let bg_replica_guard = Arc::clone(&replica_guard);
|
||||
|
||||
// TODO: here we got two threads, one for writing WAL and one for receiving
|
||||
// feedback. If one of them fails, we should shutdown the other one too.
|
||||
let _ = thread::Builder::new()
|
||||
.name("HotStandbyFeedback thread".into())
|
||||
.spawn(move || {
|
||||
let _enter =
|
||||
info_span!("HotStandbyFeedback thread", timeline = %bg_timeline_id).entered();
|
||||
if let Err(err) = Self::background_thread(bg_stream_in, bg_replica_guard) {
|
||||
error!("Replication background thread failed: {}", err);
|
||||
// Walproposer gets special handling: safekeeper must give proposer all
|
||||
// local WAL till the end, whether committed or not (walproposer will
|
||||
// hang otherwise). That's because walproposer runs the consensus and
|
||||
// synchronizes safekeepers on the most advanced one.
|
||||
//
|
||||
// There is a small risk of this WAL getting concurrently garbaged if
|
||||
// another compute rises which collects majority and starts fixing log
|
||||
// on this safekeeper itself. That's ok as (old) proposer will never be
|
||||
// able to commit such WAL.
|
||||
let stop_pos: Option<Lsn> = if self.is_walproposer_recovery() {
|
||||
let wal_end = tli.get_flush_lsn();
|
||||
Some(wal_end)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let end_pos = stop_pos.unwrap_or(Lsn::INVALID);
|
||||
|
||||
info!(
|
||||
"starting streaming from {:?} till {:?}",
|
||||
start_pos, stop_pos
|
||||
);
|
||||
|
||||
// switch to copy
|
||||
pgb.write_message(&BeMessage::CopyBothResponse)?;
|
||||
|
||||
let (_, persisted_state) = tli.get_state();
|
||||
let wal_reader = WalReader::new(
|
||||
self.conf.workdir.clone(),
|
||||
self.conf.timeline_dir(&tli.ttid),
|
||||
&persisted_state,
|
||||
start_pos,
|
||||
self.conf.wal_backup_enabled,
|
||||
)?;
|
||||
let write_ctx = SendRc::new(WriteContext {
|
||||
wal_reader: RefCell::new(wal_reader),
|
||||
send_buf: RefCell::new([0; MAX_SEND_SIZE]),
|
||||
});
|
||||
|
||||
let mut c = ReplicationContext {
|
||||
tli,
|
||||
replica_id,
|
||||
appname,
|
||||
pgb,
|
||||
start_pos,
|
||||
end_pos,
|
||||
stop_pos,
|
||||
write_ctx,
|
||||
feedback: ReplicaState::new(),
|
||||
};
|
||||
|
||||
let _phantom_wf = c.wait_wal_fut();
|
||||
let real_end_pos = c.end_pos;
|
||||
c.end_pos = c.start_pos + 1; // to well form read_wal future
|
||||
let _phantom_rf = c.read_wal_fut();
|
||||
c.end_pos = real_end_pos;
|
||||
|
||||
ReplicationHandler {
|
||||
c,
|
||||
write_state: WriteState::FlushWal,
|
||||
_phantom_wf,
|
||||
_phantom_rf,
|
||||
}
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// START_REPLICATION stream driver: sends WAL and receives feedback.
|
||||
struct ReplicationHandler<'a, WF, RF>
|
||||
where
|
||||
WF: Future<Output = anyhow::Result<Option<Lsn>>>,
|
||||
RF: Future<Output = anyhow::Result<usize>>,
|
||||
{
|
||||
c: ReplicationContext<'a>,
|
||||
#[pin]
|
||||
write_state: WriteState<WF, RF>,
|
||||
// To deduce anonymous types.
|
||||
_phantom_wf: WF,
|
||||
_phantom_rf: RF,
|
||||
}
|
||||
}
|
||||
|
||||
/// Data ReplicationHandler maintains. Separated so we could generate WriteState
|
||||
/// futures during init, deducing their type.
|
||||
struct ReplicationContext<'a> {
|
||||
tli: Arc<Timeline>,
|
||||
appname: Option<String>,
|
||||
replica_id: usize,
|
||||
pgb: &'a mut PostgresBackend,
|
||||
// Position since which we are sending next chunk.
|
||||
start_pos: Lsn,
|
||||
// WAL up to this position is known to be locally available.
|
||||
end_pos: Lsn,
|
||||
// If present, terminate after reaching this position; used by walproposer
|
||||
// in recovery.
|
||||
stop_pos: Option<Lsn>,
|
||||
// This data is needed to create Future sending WAL, so we need to both have
|
||||
// it here (to create new future) and borrow it to the future itself.
|
||||
// Essentially this is a self referential struct. To satisfy borrow checker,
|
||||
// use Rc<RefCell>. To make ReplicationHandler itself Send'able future, wrap
|
||||
// it into SendRc; this is safe as ReplicationHandler is passed between
|
||||
// threads only as a whole (during rescheduling).
|
||||
//
|
||||
// Right now we're in CurrentThread runtime, so Send is somewhat redundant;
|
||||
// however, otherwise we'd need to inconveniently have separate !Send
|
||||
// version of pg backend Handler trait (and work with LocalSet).
|
||||
write_ctx: SendRc<WriteContext>,
|
||||
feedback: ReplicaState,
|
||||
}
|
||||
|
||||
// State which ReplicationHandler needs to create futures sending data.
|
||||
struct WriteContext {
|
||||
wal_reader: RefCell<WalReader>,
|
||||
// buffer for readling WAL into to send it
|
||||
send_buf: RefCell<[u8; MAX_SEND_SIZE]>,
|
||||
}
|
||||
|
||||
// Yield points of WAL sending machinery.
|
||||
pin_project! {
|
||||
#[project = WriteStateProj]
|
||||
enum WriteState<WF, RF>
|
||||
where
|
||||
WF: Future<Output = anyhow::Result<Option<Lsn>>>,
|
||||
RF: Future<Output = anyhow::Result<usize>>,
|
||||
{
|
||||
WaitWal{ #[pin] fut: WF},
|
||||
ReadWal{ #[pin] fut: RF},
|
||||
FlushWal,
|
||||
}
|
||||
}
|
||||
|
||||
impl<WF, RF> Future for ReplicationHandler<'_, WF, RF>
|
||||
where
|
||||
WF: Future<Output = anyhow::Result<Option<Lsn>>>,
|
||||
RF: Future<Output = anyhow::Result<usize>>,
|
||||
{
|
||||
type Output = Result<(), QueryError>;
|
||||
|
||||
// We need to read feedback from the socket and write data there at the same
|
||||
// time. To avoid having to split socket, which creates messy split-join
|
||||
// APIs, is problematic with TLS [1] and needs to manage two tasks, just run
|
||||
// single task and use poll interfaces, basically manual state machine,
|
||||
// which is simple here.
|
||||
//
|
||||
// [1] https://github.com/tokio-rs/tls/issues/40
|
||||
//
|
||||
// Completes only when the stream is over, technically on error currently.
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
if let Poll::Ready(r) = self.as_mut().poll_read(cx) {
|
||||
return Poll::Ready(r);
|
||||
}
|
||||
self.as_mut().poll_write(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<WF, RF> ReplicationHandler<'_, WF, RF>
|
||||
where
|
||||
WF: Future<Output = anyhow::Result<Option<Lsn>>>,
|
||||
RF: Future<Output = anyhow::Result<usize>>,
|
||||
{
|
||||
// Poll reading, i.e. getting feedback and processing it. Completes only on error/end of stream.
|
||||
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), QueryError>> {
|
||||
loop {
|
||||
match ready!(self.as_mut().project().c.pgb.poll_read_message(cx)) {
|
||||
Ok(Some(msg)) => self.as_mut().handle_feedback(&msg)?,
|
||||
Ok(None) => {
|
||||
return Poll::Ready(Err(QueryError::Other(anyhow::anyhow!(
|
||||
"EOF on replication stream"
|
||||
))))
|
||||
}
|
||||
})?;
|
||||
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
runtime.block_on(async move {
|
||||
let (inmem_state, persisted_state) = tli.get_state();
|
||||
// add persisted_state.timeline_start_lsn == Lsn(0) check
|
||||
|
||||
// Walproposer gets special handling: safekeeper must give proposer all
|
||||
// local WAL till the end, whether committed or not (walproposer will
|
||||
// hang otherwise). That's because walproposer runs the consensus and
|
||||
// synchronizes safekeepers on the most advanced one.
|
||||
//
|
||||
// There is a small risk of this WAL getting concurrently garbaged if
|
||||
// another compute rises which collects majority and starts fixing log
|
||||
// on this safekeeper itself. That's ok as (old) proposer will never be
|
||||
// able to commit such WAL.
|
||||
let stop_pos: Option<Lsn> = if spg.is_walproposer_recovery() {
|
||||
let wal_end = tli.get_flush_lsn();
|
||||
Some(wal_end)
|
||||
} else {
|
||||
None
|
||||
Err(err) => return Poll::Ready(Err(err.into())),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
info!("Start replication from {:?} till {:?}", start_pos, stop_pos);
|
||||
|
||||
// switch to copy
|
||||
pgb.write_message(&BeMessage::CopyBothResponse)?;
|
||||
|
||||
let mut end_pos = stop_pos.unwrap_or(inmem_state.commit_lsn);
|
||||
|
||||
let mut wal_reader = WalReader::new(
|
||||
spg.conf.workdir.clone(),
|
||||
spg.conf.timeline_dir(&tli.ttid),
|
||||
&persisted_state,
|
||||
start_pos,
|
||||
spg.conf.wal_backup_enabled,
|
||||
)?;
|
||||
|
||||
// buffer for wal sending, limited by MAX_SEND_SIZE
|
||||
let mut send_buf = vec![0u8; MAX_SEND_SIZE];
|
||||
|
||||
// watcher for commit_lsn updates
|
||||
let mut commit_lsn_watch_rx = tli.get_commit_lsn_watch_rx();
|
||||
|
||||
loop {
|
||||
if let Some(stop_pos) = stop_pos {
|
||||
if start_pos >= stop_pos {
|
||||
break; /* recovery finished */
|
||||
fn handle_feedback(self: Pin<&mut Self>, msg: &FeMessage) -> Result<(), QueryError> {
|
||||
let this = self.project();
|
||||
match &msg {
|
||||
FeMessage::CopyData(m) => {
|
||||
// There's three possible data messages that the client is supposed to send here:
|
||||
// `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`.
|
||||
match m.first().cloned() {
|
||||
Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
|
||||
// Note: deserializing is on m[1..] because we skip the tag byte.
|
||||
this.c.feedback.hs_feedback = HotStandbyFeedback::des(&m[1..])
|
||||
.context("failed to deserialize HotStandbyFeedback")?;
|
||||
this.c
|
||||
.tli
|
||||
.update_replica_state(this.c.replica_id, this.c.feedback);
|
||||
}
|
||||
end_pos = stop_pos;
|
||||
} else {
|
||||
/* Wait until we have some data to stream */
|
||||
let lsn = wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await?;
|
||||
Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => {
|
||||
let _reply = StandbyReply::des(&m[1..])
|
||||
.context("failed to deserialize StandbyReply")?;
|
||||
// This must be a regular postgres replica,
|
||||
// because pageserver doesn't send this type of messages to safekeeper.
|
||||
// Currently we just ignore this, tracking progress for them is not supported.
|
||||
}
|
||||
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
|
||||
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
|
||||
let buf = Bytes::copy_from_slice(&m[9..]);
|
||||
let reply = ReplicationFeedback::parse(buf);
|
||||
|
||||
if let Some(lsn) = lsn {
|
||||
end_pos = lsn;
|
||||
} else {
|
||||
// TODO: also check once in a while whether we are walsender
|
||||
// to right pageserver.
|
||||
if tli.should_walsender_stop(replica_id) {
|
||||
// Shut down, timeline is suspended.
|
||||
// TODO create proper error type for this
|
||||
bail!("end streaming to {:?}", spg.appname);
|
||||
}
|
||||
trace!("ReplicationFeedback is {:?}", reply);
|
||||
// Only pageserver sends ReplicationFeedback, so set the flag.
|
||||
// This replica is the source of information to resend to compute.
|
||||
this.c.feedback.pageserver_feedback = Some(reply);
|
||||
|
||||
// timeout expired: request pageserver status
|
||||
pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
|
||||
sent_ptr: end_pos.0,
|
||||
timestamp: get_current_timestamp(),
|
||||
request_reply: true,
|
||||
}))
|
||||
.context("Failed to send KeepAlive message")?;
|
||||
this.c
|
||||
.tli
|
||||
.update_replica_state(this.c.replica_id, this.c.feedback);
|
||||
}
|
||||
_ => warn!("unexpected message {:?}", msg),
|
||||
}
|
||||
}
|
||||
FeMessage::CopyFail => {
|
||||
// XXX we should probably (tell pgb to) close the socket, as
|
||||
// CopyFail in duplex copy is somewhat unexpected (at least to
|
||||
// PG walsender; evidently client should finish it with
|
||||
// CopyDone). Note that sync rust-postgres client (which we
|
||||
// don't use anymore) hangs otherwise.
|
||||
// https://github.com/sfackler/rust-postgres/issues/755
|
||||
// https://github.com/neondatabase/neon/issues/935
|
||||
//
|
||||
return Err(anyhow::anyhow!("unexpected CopyFail").into());
|
||||
}
|
||||
_ => {
|
||||
return Err(
|
||||
anyhow::anyhow!("unexpected message {:?} in replication stream", msg).into(),
|
||||
);
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Poll writing, i.e. sending more WAL. Completes only on error or when we
|
||||
// decide to shutdown connection -- receiver is caughtup and there is no
|
||||
// active computes; this is still handled as Err though.
|
||||
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), QueryError>> {
|
||||
// send while we don't block or error out
|
||||
loop {
|
||||
match &mut self.as_mut().project().write_state.project() {
|
||||
WriteStateProj::WaitWal { fut } => match ready!(fut.as_mut().poll(cx))? {
|
||||
Some(lsn) => {
|
||||
self.as_mut().project().c.end_pos = lsn;
|
||||
self.as_mut().start_read_wal();
|
||||
continue;
|
||||
}
|
||||
// Timed out waiting for WAL, send keepalive and possibly terminate.
|
||||
None => {
|
||||
let mut this = self.as_mut().project();
|
||||
if this.c.tli.should_walsender_stop(this.c.replica_id) {
|
||||
// Terminate if there is nothing more to send.
|
||||
// TODO close the stream properly
|
||||
return Poll::Ready(Err(anyhow::anyhow!(format!(
|
||||
"ending streaming to {:?} at {}, receiver is caughtup and there is no computes",
|
||||
self.c.appname, self.c.start_pos,
|
||||
)).into()));
|
||||
}
|
||||
let end_pos = this.c.end_pos.0;
|
||||
this.c
|
||||
.pgb
|
||||
.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
|
||||
sent_ptr: end_pos,
|
||||
timestamp: get_current_timestamp(),
|
||||
request_reply: true,
|
||||
}))?;
|
||||
/* flush KA */
|
||||
this.write_state.set(WriteState::FlushWal);
|
||||
}
|
||||
},
|
||||
WriteStateProj::ReadWal { fut } => {
|
||||
let read_len = ready!(fut.as_mut().poll(cx))?;
|
||||
assert!(read_len > 0, "read_len={}", read_len);
|
||||
|
||||
let mut this = self.as_mut().project();
|
||||
let write_ctx_clone = this.c.write_ctx.clone();
|
||||
let send_buf = &write_ctx_clone.send_buf.borrow()[..read_len];
|
||||
let chunk_end = this.c.start_pos + read_len as u64;
|
||||
// write data to the output buffer
|
||||
this.c
|
||||
.pgb
|
||||
.write_message(&BeMessage::XLogData(XLogDataBody {
|
||||
wal_start: this.c.start_pos.0,
|
||||
wal_end: chunk_end.0,
|
||||
timestamp: get_current_timestamp(),
|
||||
data: send_buf,
|
||||
}))
|
||||
.context("Failed to write XLogData")?;
|
||||
trace!("wrote a chunk of wal {}-{}", this.c.start_pos, chunk_end);
|
||||
this.c.start_pos = chunk_end;
|
||||
// and flush it
|
||||
this.write_state.set(WriteState::FlushWal);
|
||||
}
|
||||
WriteStateProj::FlushWal => {
|
||||
let this = self.as_mut().project();
|
||||
|
||||
let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize;
|
||||
let send_size = min(send_size, send_buf.len());
|
||||
|
||||
let send_buf = &mut send_buf[..send_size];
|
||||
|
||||
// read wal into buffer
|
||||
let send_size = wal_reader.read(send_buf).await?;
|
||||
let send_buf = &send_buf[..send_size];
|
||||
|
||||
// Write some data to the network socket.
|
||||
pgb.write_message(&BeMessage::XLogData(XLogDataBody {
|
||||
wal_start: start_pos.0,
|
||||
wal_end: end_pos.0,
|
||||
timestamp: get_current_timestamp(),
|
||||
data: send_buf,
|
||||
}))
|
||||
.context("Failed to send XLogData")?;
|
||||
|
||||
start_pos += send_size as u64;
|
||||
trace!("sent WAL up to {}", start_pos);
|
||||
ready!(this.c.pgb.poll_flush(cx))?;
|
||||
// If we are streaming to walproposer, check it is time to stop.
|
||||
if let Some(stop_pos) = this.c.stop_pos {
|
||||
if this.c.start_pos >= stop_pos {
|
||||
// recovery finished
|
||||
// TODO close the stream properly
|
||||
return Poll::Ready(Err(anyhow::anyhow!(format!(
|
||||
"ending streaming to walproposer at {}, receiver is caughtup and there is no computes",
|
||||
this.c.start_pos)).into()));
|
||||
}
|
||||
self.as_mut().start_read_wal();
|
||||
continue;
|
||||
} else {
|
||||
// if we don't know next portion is already available, wait
|
||||
// for it; otherwise proceed to sending
|
||||
if self.c.end_pos <= self.c.start_pos {
|
||||
self.as_mut().start_wait_wal();
|
||||
} else {
|
||||
self.as_mut().start_read_wal();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
// Start waiting for WAL, creating future doing that.
|
||||
fn start_wait_wal(self: Pin<&mut Self>) {
|
||||
let fut = self.c.wait_wal_fut();
|
||||
self.project().write_state.set(WriteState::WaitWal {
|
||||
fut: {
|
||||
// SAFETY: this function is the only way to assign WaitWal to
|
||||
// write_state. We just workaround impossibility of specifying
|
||||
// async fn type, which is anonymous.
|
||||
// transmute_copy is used as transmute refuses generic param:
|
||||
// https://users.rust-lang.org/t/transmute-doesnt-work-on-generic-types/87272
|
||||
assert_eq!(std::mem::size_of::<WF>(), std::mem::size_of_val(&fut));
|
||||
let t = unsafe { std::mem::transmute_copy(&fut) };
|
||||
std::mem::forget(fut);
|
||||
t
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Switch into reading WAL state, creating Future doing that.
|
||||
fn start_read_wal(self: Pin<&mut Self>) {
|
||||
let fut = self.c.read_wal_fut();
|
||||
self.project().write_state.set(WriteState::ReadWal {
|
||||
fut: {
|
||||
// SAFETY: this function is the only way to assign ReadWal to
|
||||
// write_state. We just workaround impossibility of specifying
|
||||
// async fn type, which is anonymous.
|
||||
// transmute_copy is used as transmute refuses generic param:
|
||||
// https://users.rust-lang.org/t/transmute-doesnt-work-on-generic-types/87272
|
||||
assert_eq!(std::mem::size_of::<RF>(), std::mem::size_of_val(&fut));
|
||||
let t = unsafe { std::mem::transmute_copy(&fut) };
|
||||
std::mem::forget(fut);
|
||||
t
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl ReplicationContext<'_> {
|
||||
// Create future waiting for WAL.
|
||||
fn wait_wal_fut(&self) -> impl Future<Output = anyhow::Result<Option<Lsn>>> {
|
||||
let mut commit_lsn_watch_rx = self.tli.get_commit_lsn_watch_rx();
|
||||
let start_pos = self.start_pos;
|
||||
async move { wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await }
|
||||
}
|
||||
|
||||
// Create future reading WAL.
|
||||
fn read_wal_fut(&self) -> impl Future<Output = anyhow::Result<usize>> {
|
||||
let mut send_size = self
|
||||
.end_pos
|
||||
.checked_sub(self.start_pos)
|
||||
.expect("reading wal without waiting for it first")
|
||||
.0 as usize;
|
||||
send_size = min(send_size, self.write_ctx.send_buf.borrow().len());
|
||||
let write_ctx_fut = self.write_ctx.clone();
|
||||
async move {
|
||||
let mut wal_reader_ref = write_ctx_fut.wal_reader.borrow_mut_send();
|
||||
let mut send_buf_ref = write_ctx_fut.send_buf.borrow_mut_send();
|
||||
|
||||
let send_buf = &mut send_buf_ref[..send_size];
|
||||
wal_reader_ref.read(send_buf).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1);
|
||||
|
||||
// Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn.
|
||||
async fn wait_for_lsn(rx: &mut Receiver<Lsn>, lsn: Lsn) -> Result<Option<Lsn>> {
|
||||
// Wait until we have commit_lsn > lsn or timeout expires. Returns
|
||||
// - Ok(Some(commit_lsn)) if needed lsn is successfully observed;
|
||||
// - Ok(None) if timeout expired;
|
||||
// - Err in case of error (if watch channel is in trouble, shouldn't happen).
|
||||
async fn wait_for_lsn(rx: &mut Receiver<Lsn>, lsn: Lsn) -> anyhow::Result<Option<Lsn>> {
|
||||
let commit_lsn: Lsn = *rx.borrow();
|
||||
if commit_lsn > lsn {
|
||||
return Ok(Some(commit_lsn));
|
||||
|
||||
@@ -346,7 +346,9 @@ impl WalBackupTask {
|
||||
backup_lsn, commit_lsn, e
|
||||
);
|
||||
|
||||
retry_attempt = retry_attempt.saturating_add(1);
|
||||
if retry_attempt < u32::MAX {
|
||||
retry_attempt += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -385,7 +387,7 @@ async fn backup_single_segment(
|
||||
) -> Result<()> {
|
||||
let segment_file_path = seg.file_path(timeline_dir)?;
|
||||
let remote_segment_path = segment_file_path
|
||||
.strip_prefix(workspace_dir)
|
||||
.strip_prefix(&workspace_dir)
|
||||
.context("Failed to strip workspace dir prefix")
|
||||
.and_then(RemotePath::new)
|
||||
.with_context(|| {
|
||||
@@ -467,7 +469,7 @@ async fn backup_object(source_file: &Path, target_file: &RemotePath, size: usize
|
||||
pub async fn read_object(
|
||||
file_path: &RemotePath,
|
||||
offset: u64,
|
||||
) -> anyhow::Result<Pin<Box<dyn tokio::io::AsyncRead>>> {
|
||||
) -> anyhow::Result<Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>> {
|
||||
let storage = REMOTE_STORAGE
|
||||
.get()
|
||||
.context("Failed to get remote storage")?
|
||||
|
||||
@@ -2,36 +2,54 @@
|
||||
//! WAL service listens for client connections and
|
||||
//! receive WAL from wal_proposer and send it to WAL receivers
|
||||
//!
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result};
|
||||
use regex::Regex;
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::thread;
|
||||
use std::{future, thread};
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::*;
|
||||
use utils::postgres_backend_async::QueryError;
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use crate::SafeKeeperConf;
|
||||
use utils::postgres_backend::{AuthType, PostgresBackend};
|
||||
use utils::postgres_backend_async::{AuthType, PostgresBackend};
|
||||
|
||||
/// Accept incoming TCP connections and spawn them into a background thread.
|
||||
pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> Result<()> {
|
||||
loop {
|
||||
match listener.accept() {
|
||||
Ok((socket, peer_addr)) => {
|
||||
debug!("accepted connection from {}", peer_addr);
|
||||
let conf = conf.clone();
|
||||
pub fn thread_main(conf: SafeKeeperConf, pg_listener: std::net::TcpListener) {
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.context("create runtime")
|
||||
// todo catch error in main thread
|
||||
.expect("failed to create runtime");
|
||||
|
||||
let _ = thread::Builder::new()
|
||||
.name("WAL service thread".into())
|
||||
.spawn(move || {
|
||||
if let Err(err) = handle_socket(socket, conf) {
|
||||
error!("connection handler exited: {}", err);
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
runtime
|
||||
.block_on(async move {
|
||||
// Tokio's from_std won't do this for us, per its comment.
|
||||
pg_listener.set_nonblocking(true)?;
|
||||
let listener = tokio::net::TcpListener::from_std(pg_listener)?;
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((socket, peer_addr)) => {
|
||||
debug!("accepted connection from {}", peer_addr);
|
||||
let conf = conf.clone();
|
||||
|
||||
let _ = thread::Builder::new()
|
||||
.name("WAL service thread".into())
|
||||
.spawn(move || {
|
||||
if let Err(err) = handle_socket(socket, conf) {
|
||||
error!("connection handler exited: {}", err);
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
Err(e) => error!("Failed to accept connection: {}", e),
|
||||
}
|
||||
}
|
||||
Err(e) => error!("Failed to accept connection: {}", e),
|
||||
}
|
||||
}
|
||||
#[allow(unreachable_code)] // hint compiler the closure return type
|
||||
Ok::<(), anyhow::Error>(())
|
||||
})
|
||||
.expect("listener failed")
|
||||
}
|
||||
|
||||
// Get unique thread id (Rust internal), with ThreadId removed for shorter printing
|
||||
@@ -44,9 +62,14 @@ fn get_tid() -> u64 {
|
||||
|
||||
/// This is run by `thread_main` above, inside a background thread.
|
||||
///
|
||||
fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<()> {
|
||||
fn handle_socket(mut socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> {
|
||||
let _enter = info_span!("", tid = ?get_tid()).entered();
|
||||
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
let local = tokio::task::LocalSet::new();
|
||||
|
||||
socket.set_nodelay(true)?;
|
||||
|
||||
let auth_type = match conf.auth {
|
||||
@@ -54,9 +77,13 @@ fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<()> {
|
||||
Some(_) => AuthType::NeonJWT,
|
||||
};
|
||||
let mut conn_handler = SafekeeperPostgresHandler::new(conf);
|
||||
let pgbackend = PostgresBackend::new(socket, auth_type, None, false)?;
|
||||
// libpq replication protocol between safekeeper and replicas/pagers
|
||||
pgbackend.run(&mut conn_handler)?;
|
||||
let pgbackend = PostgresBackend::new(socket, auth_type, None)?;
|
||||
// libpq protocol between safekeeper and walproposer / pageserver
|
||||
// We don't use shutdown.
|
||||
local.block_on(
|
||||
&runtime,
|
||||
pgbackend.run(&mut conn_handler, || future::pending::<()>()),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -450,7 +450,7 @@ pub struct WalReader {
|
||||
timeline_dir: PathBuf,
|
||||
wal_seg_size: usize,
|
||||
pos: Lsn,
|
||||
wal_segment: Option<Pin<Box<dyn AsyncRead>>>,
|
||||
wal_segment: Option<Pin<Box<dyn AsyncRead + Send + Sync>>>,
|
||||
|
||||
// S3 will be used to read WAL if LSN is not available locally
|
||||
enable_remote_read: bool,
|
||||
@@ -491,6 +491,11 @@ impl WalReader {
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn fake_read(&mut self) -> Result<usize> {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
Ok(self.pos.0 as usize)
|
||||
}
|
||||
|
||||
pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
|
||||
let mut wal_segment = match self.wal_segment.take() {
|
||||
Some(reader) => reader,
|
||||
@@ -517,7 +522,7 @@ impl WalReader {
|
||||
}
|
||||
|
||||
/// Open WAL segment at the current position of the reader.
|
||||
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead>>> {
|
||||
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead + Send + Sync>>> {
|
||||
let xlogoff = self.pos.segment_offset(self.wal_seg_size);
|
||||
let segno = self.pos.segment_number(self.wal_seg_size);
|
||||
let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size);
|
||||
|
||||
@@ -40,10 +40,9 @@ def parse_metrics(text: str, name: str = "") -> Metrics:
|
||||
|
||||
|
||||
PAGESERVER_PER_TENANT_REMOTE_TIMELINE_CLIENT_METRICS: Tuple[str, ...] = (
|
||||
"pageserver_remote_upload_queue_unfinished_tasks",
|
||||
"pageserver_remote_operation_seconds_bucket",
|
||||
"pageserver_remote_operation_seconds_count",
|
||||
"pageserver_remote_operation_seconds_sum",
|
||||
"pageserver_remote_timeline_client_calls_unfinished",
|
||||
*[f"pageserver_remote_timeline_client_calls_started_{x}" for x in ["bucket", "count", "sum"]],
|
||||
*[f"pageserver_remote_operation_seconds_{x}" for x in ["bucket", "count", "sum"]],
|
||||
"pageserver_remote_physical_size",
|
||||
)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from contextlib import closing, contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Flag, auto
|
||||
from functools import cached_property
|
||||
from itertools import chain, product
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
|
||||
@@ -34,6 +35,7 @@ from _pytest.config import Config
|
||||
from _pytest.config.argparsing import Parser
|
||||
from _pytest.fixtures import FixtureRequest
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.metrics import parse_metrics
|
||||
from fixtures.types import Lsn, TenantId, TimelineId
|
||||
from fixtures.utils import (
|
||||
ATTACHMENT_NAME_REGEX,
|
||||
@@ -595,6 +597,7 @@ class NeonEnvBuilder:
|
||||
rust_log_override: Optional[str] = None,
|
||||
default_branch_name: str = DEFAULT_BRANCH_NAME,
|
||||
preserve_database_files: bool = False,
|
||||
initial_tenant: Optional[TenantId] = None,
|
||||
):
|
||||
self.repo_dir = repo_dir
|
||||
self.rust_log_override = rust_log_override
|
||||
@@ -617,8 +620,9 @@ class NeonEnvBuilder:
|
||||
self.pg_distrib_dir = pg_distrib_dir
|
||||
self.pg_version = pg_version
|
||||
self.preserve_database_files = preserve_database_files
|
||||
self.initial_tenant = initial_tenant or TenantId.generate()
|
||||
|
||||
def init(self) -> NeonEnv:
|
||||
def init_configs(self) -> NeonEnv:
|
||||
# Cannot create more than one environment from one builder
|
||||
assert self.env is None, "environment already initialized"
|
||||
self.env = NeonEnv(self)
|
||||
@@ -629,8 +633,17 @@ class NeonEnvBuilder:
|
||||
self.env.start()
|
||||
|
||||
def init_start(self) -> NeonEnv:
|
||||
env = self.init()
|
||||
env = self.init_configs()
|
||||
self.start()
|
||||
|
||||
# Prepare the default branch to start the postgres on later.
|
||||
# Pageserver itself does not create tenants and timelines, until started first and asked via HTTP API.
|
||||
log.info(
|
||||
f"Services started, creating initial tenant {env.initial_tenant} and its initial timeline"
|
||||
)
|
||||
initial_tenant, initial_timeline = env.neon_cli.create_tenant(tenant_id=env.initial_tenant)
|
||||
log.info(f"Initial timeline {initial_tenant}/{initial_timeline} created successfully")
|
||||
|
||||
return env
|
||||
|
||||
def enable_remote_storage(
|
||||
@@ -889,12 +902,12 @@ class NeonEnv:
|
||||
|
||||
# generate initial tenant ID here instead of letting 'neon init' generate it,
|
||||
# so that we don't need to dig it out of the config file afterwards.
|
||||
self.initial_tenant = TenantId.generate()
|
||||
self.initial_tenant = config.initial_tenant
|
||||
|
||||
# Create a config file corresponding to the options
|
||||
toml = textwrap.dedent(
|
||||
f"""
|
||||
default_tenant_id = '{self.initial_tenant}'
|
||||
default_tenant_id = '{config.initial_tenant}'
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -1409,6 +1422,33 @@ class PageserverHttpClient(requests.Session):
|
||||
]
|
||||
return sample.value
|
||||
|
||||
def get_remote_timeline_client_metric(
|
||||
self,
|
||||
metric_name: str,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
file_kind: str,
|
||||
op_kind: str,
|
||||
) -> Optional[float]:
|
||||
metrics = parse_metrics(self.get_metrics(), "pageserver")
|
||||
matches = metrics.query_all(
|
||||
name=metric_name,
|
||||
filter={
|
||||
"tenant_id": str(tenant_id),
|
||||
"timeline_id": str(timeline_id),
|
||||
"file_kind": str(file_kind),
|
||||
"op_kind": str(op_kind),
|
||||
},
|
||||
)
|
||||
if len(matches) == 0:
|
||||
value = None
|
||||
elif len(matches) == 1:
|
||||
value = matches[0].value
|
||||
assert value is not None
|
||||
else:
|
||||
assert len(matches) < 2, "above filter should uniquely identify metric"
|
||||
return value
|
||||
|
||||
def get_metric_value(self, name: str) -> Optional[str]:
|
||||
metrics = self.get_metrics()
|
||||
relevant = [line for line in metrics.splitlines() if line.startswith(name)]
|
||||
@@ -1528,6 +1568,7 @@ class NeonCli(AbstractNeonCli):
|
||||
tenant_id: Optional[TenantId] = None,
|
||||
timeline_id: Optional[TimelineId] = None,
|
||||
conf: Optional[Dict[str, str]] = None,
|
||||
set_default: bool = False,
|
||||
) -> Tuple[TenantId, TimelineId]:
|
||||
"""
|
||||
Creates a new tenant, returns its id and its initial timeline's id.
|
||||
@@ -1536,47 +1577,51 @@ class NeonCli(AbstractNeonCli):
|
||||
tenant_id = TenantId.generate()
|
||||
if timeline_id is None:
|
||||
timeline_id = TimelineId.generate()
|
||||
if conf is None:
|
||||
res = self.raw_cli(
|
||||
[
|
||||
"tenant",
|
||||
"create",
|
||||
"--tenant-id",
|
||||
str(tenant_id),
|
||||
"--timeline-id",
|
||||
str(timeline_id),
|
||||
"--pg-version",
|
||||
self.env.pg_version,
|
||||
]
|
||||
)
|
||||
else:
|
||||
res = self.raw_cli(
|
||||
[
|
||||
"tenant",
|
||||
"create",
|
||||
"--tenant-id",
|
||||
str(tenant_id),
|
||||
"--timeline-id",
|
||||
str(timeline_id),
|
||||
"--pg-version",
|
||||
self.env.pg_version,
|
||||
]
|
||||
+ sum(list(map(lambda kv: (["-c", kv[0] + ":" + kv[1]]), conf.items())), [])
|
||||
|
||||
args = [
|
||||
"tenant",
|
||||
"create",
|
||||
"--tenant-id",
|
||||
str(tenant_id),
|
||||
"--timeline-id",
|
||||
str(timeline_id),
|
||||
"--pg-version",
|
||||
self.env.pg_version,
|
||||
]
|
||||
if conf is not None:
|
||||
args.extend(
|
||||
chain.from_iterable(
|
||||
product(["-c"], (f"{key}:{value}" for key, value in conf.items()))
|
||||
)
|
||||
)
|
||||
if set_default:
|
||||
args.append("--set-default")
|
||||
|
||||
res = self.raw_cli(args)
|
||||
res.check_returncode()
|
||||
return tenant_id, timeline_id
|
||||
|
||||
def set_default(self, tenant_id: TenantId):
|
||||
"""
|
||||
Update default tenant for future operations that require tenant_id.
|
||||
"""
|
||||
res = self.raw_cli(["tenant", "set-default", "--tenant-id", str(tenant_id)])
|
||||
res.check_returncode()
|
||||
|
||||
def config_tenant(self, tenant_id: TenantId, conf: Dict[str, str]):
|
||||
"""
|
||||
Update tenant config.
|
||||
"""
|
||||
if conf is None:
|
||||
res = self.raw_cli(["tenant", "config", "--tenant-id", str(tenant_id)])
|
||||
else:
|
||||
res = self.raw_cli(
|
||||
["tenant", "config", "--tenant-id", str(tenant_id)]
|
||||
+ sum(list(map(lambda kv: (["-c", kv[0] + ":" + kv[1]]), conf.items())), [])
|
||||
|
||||
args = ["tenant", "config", "--tenant-id", str(tenant_id)]
|
||||
if conf is not None:
|
||||
args.extend(
|
||||
chain.from_iterable(
|
||||
product(["-c"], (f"{key}:{value}" for key, value in conf.items()))
|
||||
)
|
||||
)
|
||||
|
||||
res = self.raw_cli(args)
|
||||
res.check_returncode()
|
||||
|
||||
def list_tenants(self) -> "subprocess.CompletedProcess[str]":
|
||||
@@ -1611,36 +1656,6 @@ class NeonCli(AbstractNeonCli):
|
||||
|
||||
return TimelineId(str(created_timeline_id))
|
||||
|
||||
def create_root_branch(
|
||||
self,
|
||||
branch_name: str,
|
||||
tenant_id: Optional[TenantId] = None,
|
||||
):
|
||||
cmd = [
|
||||
"timeline",
|
||||
"create",
|
||||
"--branch-name",
|
||||
branch_name,
|
||||
"--tenant-id",
|
||||
str(tenant_id or self.env.initial_tenant),
|
||||
"--pg-version",
|
||||
self.env.pg_version,
|
||||
]
|
||||
|
||||
res = self.raw_cli(cmd)
|
||||
res.check_returncode()
|
||||
|
||||
matches = CREATE_TIMELINE_ID_EXTRACTOR.search(res.stdout)
|
||||
|
||||
created_timeline_id = None
|
||||
if matches is not None:
|
||||
created_timeline_id = matches.group("timeline_id")
|
||||
|
||||
if created_timeline_id is None:
|
||||
raise Exception("could not find timeline id after `neon timeline create` invocation")
|
||||
else:
|
||||
return TimelineId(created_timeline_id)
|
||||
|
||||
def create_branch(
|
||||
self,
|
||||
new_branch_name: str = DEFAULT_BRANCH_NAME,
|
||||
@@ -1696,17 +1711,12 @@ class NeonCli(AbstractNeonCli):
|
||||
def init(
|
||||
self,
|
||||
config_toml: str,
|
||||
initial_timeline_id: Optional[TimelineId] = None,
|
||||
) -> "subprocess.CompletedProcess[str]":
|
||||
with tempfile.NamedTemporaryFile(mode="w+") as tmp:
|
||||
tmp.write(config_toml)
|
||||
tmp.flush()
|
||||
|
||||
cmd = ["init", f"--config={tmp.name}"]
|
||||
if initial_timeline_id:
|
||||
cmd.extend(["--timeline-id", str(initial_timeline_id)])
|
||||
|
||||
cmd.extend(["--pg-version", self.env.pg_version])
|
||||
cmd = ["init", f"--config={tmp.name}", "--pg-version", self.env.pg_version]
|
||||
|
||||
append_pageserver_param_overrides(
|
||||
params_to_update=cmd,
|
||||
@@ -1903,14 +1913,17 @@ class NeonPageserver(PgProtocol):
|
||||
".*wal receiver task finished with an error: walreceiver connection handling failure.*",
|
||||
".*Shutdown task error: walreceiver connection handling failure.*",
|
||||
".*wal_connection_manager.*tcp connect error: Connection refused.*",
|
||||
".*query handler for .* failed: Connection reset by peer.*",
|
||||
".*serving compute connection task.*exited with error: Broken pipe.*",
|
||||
".*Connection aborted: error communicating with the server: Broken pipe.*",
|
||||
".*Connection aborted: error communicating with the server: Transport endpoint is not connected.*",
|
||||
".*Connection aborted: error communicating with the server: Connection reset by peer.*",
|
||||
".*query handler for .* failed: Socket IO error: Connection reset by peer.*",
|
||||
".*serving compute connection task.*exited with error: Postgres connection error.*",
|
||||
".*serving compute connection task.*exited with error: Connection reset by peer.*",
|
||||
".*serving compute connection task.*exited with error: Postgres query error.*",
|
||||
".*Connection aborted: connection error: error communicating with the server: Broken pipe.*",
|
||||
".*Connection aborted: connection error: error communicating with the server: Transport endpoint is not connected.*",
|
||||
".*Connection aborted: connection error: error communicating with the server: Connection reset by peer.*",
|
||||
".*kill_and_wait_impl.*: wait successful.*",
|
||||
".*end streaming to Some.*",
|
||||
".*Replication stream finished: db error: ERROR: Socket IO error: end streaming to Some.*",
|
||||
".*query handler for 'pagestream.*failed: Broken pipe.*", # pageserver notices compute shut down
|
||||
".*query handler for 'pagestream.*failed: Connection reset by peer.*", # pageserver notices compute shut down
|
||||
# safekeeper connection can fail with this, in the window between timeline creation
|
||||
# and streaming start
|
||||
".*Failed to process query for timeline .*: state uninitialized, no data to read.*",
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.metrics import parse_metrics
|
||||
@@ -20,9 +22,19 @@ def httpserver_listen_address(port_distributor: PortDistributor):
|
||||
return ("localhost", port)
|
||||
|
||||
|
||||
num_metrics_received = 0
|
||||
initial_tenant = TenantId.generate()
|
||||
remote_uploaded = 0
|
||||
first_request = True
|
||||
checks = {
|
||||
"written_size": lambda value: value > 0,
|
||||
"resident_size": lambda value: value >= 0,
|
||||
# >= 0 check here is to avoid race condition when we receive metrics before
|
||||
# remote_uploaded is updated
|
||||
"remote_storage_size": lambda value: value > 0 if remote_uploaded > 0 else value >= 0,
|
||||
# logical size may lag behind the actual size, so allow 0 here
|
||||
"timeline_logical_size": lambda value: value >= 0,
|
||||
}
|
||||
|
||||
metric_kinds_checked = set([])
|
||||
|
||||
|
||||
#
|
||||
@@ -36,38 +48,19 @@ def metrics_handler(request: Request) -> Response:
|
||||
log.info("received events:")
|
||||
log.info(events)
|
||||
|
||||
checks = {
|
||||
"written_size": lambda value: value > 0,
|
||||
"resident_size": lambda value: value >= 0,
|
||||
# >= 0 check here is to avoid race condition when we receive metrics before
|
||||
# remote_uploaded is updated
|
||||
"remote_storage_size": lambda value: value > 0 if remote_uploaded > 0 else value >= 0,
|
||||
# logical size may lag behind the actual size, so allow 0 here
|
||||
"timeline_logical_size": lambda value: value >= 0,
|
||||
}
|
||||
|
||||
events_received = 0
|
||||
for event in events:
|
||||
check = checks.get(event["metric"])
|
||||
assert event["tenant_id"] == str(
|
||||
initial_tenant
|
||||
), "Expecting metrics only from the initial tenant"
|
||||
metric_name = event["metric"]
|
||||
|
||||
check = checks.get(metric_name)
|
||||
# calm down mypy
|
||||
if check is not None:
|
||||
assert check(event["value"]), f"{event['metric']} isn't valid"
|
||||
events_received += 1
|
||||
assert check(event["value"]), f"{metric_name} isn't valid"
|
||||
global metric_kinds_checked
|
||||
metric_kinds_checked.add(metric_name)
|
||||
|
||||
global first_request
|
||||
# check that all checks were sent
|
||||
# but only on the first request, because we don't send non-changed metrics
|
||||
if first_request:
|
||||
# we may receive more metrics than we check,
|
||||
# because there are two timelines
|
||||
# and we may receive per-timeline metrics from both
|
||||
# if the test was slow enough for these metrics to be collected
|
||||
# -1 because that is ok to not receive timeline_logical_size
|
||||
assert events_received >= len(checks) - 1
|
||||
first_request = False
|
||||
|
||||
global num_metrics_received
|
||||
num_metrics_received += 1
|
||||
return Response(status=200)
|
||||
|
||||
|
||||
@@ -83,11 +76,14 @@ def test_metric_collection(
|
||||
(host, port) = httpserver_listen_address
|
||||
metric_collection_endpoint = f"http://{host}:{port}/billing/api/v1/usage_events"
|
||||
|
||||
# Require collecting metrics frequently, since we change
|
||||
# the timeline and want something to be logged about it.
|
||||
#
|
||||
# Disable time-based pitr, we will use the manual GC calls
|
||||
# to trigger remote storage operations in a controlled way
|
||||
neon_env_builder.pageserver_config_override = (
|
||||
f"""
|
||||
metric_collection_interval="60s"
|
||||
metric_collection_interval="1s"
|
||||
metric_collection_endpoint="{metric_collection_endpoint}"
|
||||
"""
|
||||
+ "tenant_config={pitr_interval = '0 sec'}"
|
||||
@@ -100,6 +96,9 @@ def test_metric_collection(
|
||||
|
||||
log.info(f"test_metric_collection endpoint is {metric_collection_endpoint}")
|
||||
|
||||
# Set initial tenant of the test, that we expect the logs from
|
||||
global initial_tenant
|
||||
initial_tenant = neon_env_builder.initial_tenant
|
||||
# mock http server that returns OK for the metrics
|
||||
httpserver.expect_request("/billing/api/v1/usage_events", method="POST").respond_with_handler(
|
||||
metrics_handler
|
||||
@@ -154,7 +153,11 @@ def test_metric_collection(
|
||||
remote_uploaded = get_num_remote_ops("index", "upload")
|
||||
assert remote_uploaded > 0
|
||||
|
||||
# check that all requests are served
|
||||
# wait longer than collecting interval and check that all requests are served
|
||||
time.sleep(3)
|
||||
httpserver.check()
|
||||
global num_metrics_received
|
||||
assert num_metrics_received > 0, "no metrics were received"
|
||||
global metric_kinds_checked, checks
|
||||
expected_checks = set(checks.keys())
|
||||
assert len(metric_kinds_checked) == len(
|
||||
checks
|
||||
), f"Expected to receive and check all kind of metrics, but {expected_checks - metric_kinds_checked} got uncovered"
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
from fixtures.neon_fixtures import NeonEnvBuilder
|
||||
from fixtures.neon_fixtures import NeonEnvBuilder, PortDistributor
|
||||
|
||||
|
||||
# Test that neon cli is able to start and stop all processes with the user defaults.
|
||||
# def test_neon_cli_basics(neon_simple_env: NeonEnv):
|
||||
def test_neon_cli_basics(neon_env_builder: NeonEnvBuilder):
|
||||
env = neon_env_builder.init()
|
||||
# Repeats the example from README.md as close as it can
|
||||
def test_neon_cli_basics(neon_env_builder: NeonEnvBuilder, port_distributor: PortDistributor):
|
||||
env = neon_env_builder.init_configs()
|
||||
# Skipping the init step that creates a local tenant in Pytest tests
|
||||
try:
|
||||
env.neon_cli.start()
|
||||
env.neon_cli.create_tenant(tenant_id=env.initial_tenant, set_default=True)
|
||||
env.neon_cli.pg_start(node_name="main", port=port_distributor.get_port())
|
||||
|
||||
env.neon_cli.start()
|
||||
env.neon_cli.stop()
|
||||
env.neon_cli.create_branch(new_branch_name="migration_check")
|
||||
env.neon_cli.pg_start(node_name="migration_check", port=port_distributor.get_port())
|
||||
finally:
|
||||
env.neon_cli.stop()
|
||||
|
||||
@@ -120,7 +120,7 @@ def test_ondemand_download_large_rel(
|
||||
|
||||
|
||||
#
|
||||
# If you have a relation with a long history of updates,the pageserver downloads the layer
|
||||
# If you have a relation with a long history of updates, the pageserver downloads the layer
|
||||
# files containing the history as needed by timetravel queries.
|
||||
#
|
||||
@pytest.mark.parametrize("remote_storage_kind", available_remote_storages())
|
||||
@@ -189,13 +189,10 @@ def test_ondemand_download_timetravel(
|
||||
# run checkpoint manually to be sure that data landed in remote storage
|
||||
client.timeline_checkpoint(tenant_id, timeline_id)
|
||||
|
||||
# wait until pageserver successfully uploaded a checkpoint to remote storage
|
||||
wait_for_upload(client, tenant_id, timeline_id, current_lsn)
|
||||
log.info("uploads have finished")
|
||||
|
||||
##### Stop the first pageserver instance, erase all its data
|
||||
env.postgres.stop_all()
|
||||
|
||||
# wait until pageserver has successfully uploaded all the data to remote storage
|
||||
wait_for_sk_commit_lsn_to_reach_remote_storage(
|
||||
tenant_id, timeline_id, env.safekeepers, env.pageserver
|
||||
)
|
||||
@@ -227,11 +224,15 @@ def test_ondemand_download_timetravel(
|
||||
|
||||
wait_until(10, 0.2, lambda: assert_tenant_status(client, tenant_id, "Active"))
|
||||
|
||||
# current_physical_size reports sum of layer file sizes, regardless of local or remote
|
||||
# The current_physical_size reports the sum of layers loaded in the layer
|
||||
# map, regardless of where the layer files are located. So even though we
|
||||
# just removed the local files, they still count towards
|
||||
# current_physical_size because they are loaded as `RemoteLayer`s.
|
||||
assert filled_current_physical == get_api_current_physical_size()
|
||||
|
||||
# Run queries at different points in time
|
||||
num_layers_downloaded = [0]
|
||||
physical_size = [get_resident_physical_size()]
|
||||
resident_size = [get_resident_physical_size()]
|
||||
for (checkpoint_number, lsn) in lsns:
|
||||
pg_old = env.postgres.create_start(
|
||||
branch_name="main", node_name=f"test_old_lsn_{checkpoint_number}", lsn=lsn
|
||||
@@ -268,13 +269,15 @@ def test_ondemand_download_timetravel(
|
||||
if len(num_layers_downloaded) > 4:
|
||||
assert after_downloads > num_layers_downloaded[len(num_layers_downloaded) - 4]
|
||||
|
||||
# Likewise, assert that the physical_size metric grows as layers are downloaded
|
||||
physical_size.append(get_resident_physical_size())
|
||||
log.info(f"physical_size[-1]={physical_size[-1]}")
|
||||
if len(physical_size) > 4:
|
||||
assert physical_size[-1] > physical_size[len(physical_size) - 4]
|
||||
# Likewise, assert that the resident_physical_size metric grows as layers are downloaded
|
||||
resident_size.append(get_resident_physical_size())
|
||||
log.info(f"resident_size[-1]={resident_size[-1]}")
|
||||
if len(resident_size) > 4:
|
||||
assert resident_size[-1] > resident_size[len(resident_size) - 4]
|
||||
|
||||
# current_physical_size reports sum of layer file sizes, regardless of local or remote
|
||||
# current_physical_size reports the total size of all layer files, whether
|
||||
# they are present only in the remote storage, only locally, or both.
|
||||
# It should not change.
|
||||
assert filled_current_physical == get_api_current_physical_size()
|
||||
|
||||
|
||||
|
||||
@@ -12,11 +12,9 @@ def test_pageserver_recovery(neon_env_builder: NeonEnvBuilder):
|
||||
# Override default checkpointer settings to run it more often
|
||||
neon_env_builder.pageserver_config_override = "tenant_config={checkpoint_distance = 1048576}"
|
||||
|
||||
env = neon_env_builder.init()
|
||||
env = neon_env_builder.init_start()
|
||||
env.pageserver.is_testing_enabled_or_skip()
|
||||
|
||||
neon_env_builder.start()
|
||||
|
||||
# These warnings are expected, when the pageserver is restarted abruptly
|
||||
env.pageserver.allowed_errors.append(".*found future delta layer.*")
|
||||
env.pageserver.allowed_errors.append(".*found future image layer.*")
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
# env NEON_PAGESERVER_OVERRIDES="remote_storage={local_path='/tmp/neon_zzz/'}" poetry ......
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
from fixtures.log_helper import log
|
||||
@@ -271,14 +271,15 @@ def test_remote_storage_upload_queue_retries(
|
||||
wait_for_last_flush_lsn(env, pg, tenant_id, timeline_id)
|
||||
|
||||
def get_queued_count(file_kind, op_kind):
|
||||
metrics = client.get_metrics()
|
||||
matches = re.search(
|
||||
f'^pageserver_remote_upload_queue_unfinished_tasks{{file_kind="{file_kind}",op_kind="{op_kind}",tenant_id="{tenant_id}",timeline_id="{timeline_id}"}} (\\S+)$',
|
||||
metrics,
|
||||
re.MULTILINE,
|
||||
val = client.get_remote_timeline_client_metric(
|
||||
"pageserver_remote_timeline_client_calls_unfinished",
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
file_kind,
|
||||
op_kind,
|
||||
)
|
||||
assert matches
|
||||
return int(matches[1])
|
||||
assert val is not None, "expecting metric to be present"
|
||||
return int(val)
|
||||
|
||||
# create some layers & wait for uploads to finish
|
||||
overwrite_data_and_wait_for_it_to_arrive_at_pageserver("a")
|
||||
@@ -368,6 +369,168 @@ def test_remote_storage_upload_queue_retries(
|
||||
assert query_scalar(cur, "SELECT COUNT(*) FROM foo WHERE val = 'd'") == 10000
|
||||
|
||||
|
||||
@pytest.mark.parametrize("remote_storage_kind", [RemoteStorageKind.LOCAL_FS])
|
||||
def test_remote_timeline_client_calls_started_metric(
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
remote_storage_kind: RemoteStorageKind,
|
||||
):
|
||||
|
||||
neon_env_builder.enable_remote_storage(
|
||||
remote_storage_kind=remote_storage_kind,
|
||||
test_name="test_remote_timeline_client_metrics",
|
||||
)
|
||||
|
||||
env = neon_env_builder.init_start()
|
||||
|
||||
# create tenant with config that will determinstically allow
|
||||
# compaction and gc
|
||||
tenant_id, timeline_id = env.neon_cli.create_tenant(
|
||||
conf={
|
||||
# small checkpointing and compaction targets to ensure we generate many upload operations
|
||||
"checkpoint_distance": f"{128 * 1024}",
|
||||
"compaction_threshold": "1",
|
||||
"compaction_target_size": f"{128 * 1024}",
|
||||
# no PITR horizon, we specify the horizon when we request on-demand GC
|
||||
"pitr_interval": "0s",
|
||||
# disable background compaction and GC. We invoke it manually when we want it to happen.
|
||||
"gc_period": "0s",
|
||||
"compaction_period": "0s",
|
||||
# don't create image layers, that causes just noise
|
||||
"image_creation_threshold": "10000",
|
||||
}
|
||||
)
|
||||
|
||||
client = env.pageserver.http_client()
|
||||
|
||||
pg = env.postgres.create_start("main", tenant_id=tenant_id)
|
||||
|
||||
pg.safe_psql("CREATE TABLE foo (id INTEGER PRIMARY KEY, val text)")
|
||||
|
||||
def overwrite_data_and_wait_for_it_to_arrive_at_pageserver(data):
|
||||
# create initial set of layers & upload them with failpoints configured
|
||||
pg.safe_psql_many(
|
||||
[
|
||||
f"""
|
||||
INSERT INTO foo (id, val)
|
||||
SELECT g, '{data}'
|
||||
FROM generate_series(1, 10000) g
|
||||
ON CONFLICT (id) DO UPDATE
|
||||
SET val = EXCLUDED.val
|
||||
""",
|
||||
# to ensure that GC can actually remove some layers
|
||||
"VACUUM foo",
|
||||
]
|
||||
)
|
||||
wait_for_last_flush_lsn(env, pg, tenant_id, timeline_id)
|
||||
|
||||
def get_queued_count(file_kind, op_kind):
|
||||
val = client.get_remote_timeline_client_metric(
|
||||
"pageserver_remote_timeline_client_calls_unfinished",
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
file_kind,
|
||||
op_kind,
|
||||
)
|
||||
if val is None:
|
||||
return val
|
||||
return int(val)
|
||||
|
||||
def wait_upload_queue_empty():
|
||||
wait_until(2, 1, lambda: get_queued_count(file_kind="layer", op_kind="upload") == 0)
|
||||
wait_until(2, 1, lambda: get_queued_count(file_kind="index", op_kind="upload") == 0)
|
||||
wait_until(2, 1, lambda: get_queued_count(file_kind="layer", op_kind="delete") == 0)
|
||||
|
||||
calls_started: Dict[Tuple[str, str], List[int]] = {
|
||||
("layer", "upload"): [0],
|
||||
("index", "upload"): [0],
|
||||
("layer", "delete"): [0],
|
||||
}
|
||||
|
||||
def fetch_calls_started():
|
||||
for (file_kind, op_kind), observations in calls_started.items():
|
||||
val = client.get_remote_timeline_client_metric(
|
||||
"pageserver_remote_timeline_client_calls_started_count",
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
file_kind,
|
||||
op_kind,
|
||||
)
|
||||
assert val is not None, f"expecting metric to be present: {file_kind} {op_kind}"
|
||||
val = int(val)
|
||||
observations.append(val)
|
||||
|
||||
def ensure_calls_started_grew():
|
||||
for (file_kind, op_kind), observations in calls_started.items():
|
||||
log.info(f"ensure_calls_started_grew: {file_kind} {op_kind}: {observations}")
|
||||
assert all(
|
||||
x < y for x, y in zip(observations, observations[1:])
|
||||
), f"observations for {file_kind} {op_kind} did not grow monotonically: {observations}"
|
||||
|
||||
def churn(data_pass1, data_pass2):
|
||||
overwrite_data_and_wait_for_it_to_arrive_at_pageserver(data_pass1)
|
||||
client.timeline_checkpoint(tenant_id, timeline_id)
|
||||
client.timeline_compact(tenant_id, timeline_id)
|
||||
overwrite_data_and_wait_for_it_to_arrive_at_pageserver(data_pass2)
|
||||
client.timeline_checkpoint(tenant_id, timeline_id)
|
||||
client.timeline_compact(tenant_id, timeline_id)
|
||||
gc_result = client.timeline_gc(tenant_id, timeline_id, 0)
|
||||
print_gc_result(gc_result)
|
||||
assert gc_result["layers_removed"] > 0
|
||||
|
||||
# create some layers & wait for uploads to finish
|
||||
churn("a", "b")
|
||||
|
||||
wait_upload_queue_empty()
|
||||
|
||||
# ensure that we updated the calls_started metric
|
||||
fetch_calls_started()
|
||||
ensure_calls_started_grew()
|
||||
|
||||
# more churn to cause more operations
|
||||
churn("c", "d")
|
||||
|
||||
# ensure that the calls_started metric continued to be updated
|
||||
fetch_calls_started()
|
||||
ensure_calls_started_grew()
|
||||
|
||||
### now we exercise the download path
|
||||
calls_started.clear()
|
||||
calls_started.update(
|
||||
{
|
||||
("index", "download"): [0],
|
||||
("layer", "download"): [0],
|
||||
}
|
||||
)
|
||||
|
||||
env.pageserver.stop(immediate=True)
|
||||
env.postgres.stop_all()
|
||||
|
||||
dir_to_clear = Path(env.repo_dir) / "tenants"
|
||||
shutil.rmtree(dir_to_clear)
|
||||
os.mkdir(dir_to_clear)
|
||||
|
||||
env.pageserver.start()
|
||||
client = env.pageserver.http_client()
|
||||
|
||||
client.tenant_attach(tenant_id)
|
||||
|
||||
def tenant_active():
|
||||
all_states = client.tenant_list()
|
||||
[tenant] = [t for t in all_states if TenantId(t["id"]) == tenant_id]
|
||||
assert tenant["state"] == "Active"
|
||||
|
||||
wait_until(30, 1, tenant_active)
|
||||
|
||||
log.info("restarting postgres to validate")
|
||||
pg = env.postgres.create_start("main", tenant_id=tenant_id)
|
||||
with pg.cursor() as cur:
|
||||
assert query_scalar(cur, "SELECT COUNT(*) FROM foo WHERE val = 'd'") == 10000
|
||||
|
||||
# ensure that we updated the calls_started download metric
|
||||
fetch_calls_started()
|
||||
ensure_calls_started_grew()
|
||||
|
||||
|
||||
# Test that we correctly handle timeline with layers stuck in upload queue
|
||||
@pytest.mark.parametrize("remote_storage_kind", [RemoteStorageKind.LOCAL_FS])
|
||||
def test_timeline_deletion_with_files_stuck_in_upload_queue(
|
||||
@@ -401,15 +564,14 @@ def test_timeline_deletion_with_files_stuck_in_upload_queue(
|
||||
client = env.pageserver.http_client()
|
||||
|
||||
def get_queued_count(file_kind, op_kind):
|
||||
metrics = client.get_metrics()
|
||||
matches = re.search(
|
||||
f'^pageserver_remote_upload_queue_unfinished_tasks{{file_kind="{file_kind}",op_kind="{op_kind}",tenant_id="{tenant_id}",timeline_id="{timeline_id}"}} (\\S+)$',
|
||||
metrics,
|
||||
re.MULTILINE,
|
||||
val = client.get_remote_timeline_client_metric(
|
||||
"pageserver_remote_timeline_client_calls_unfinished",
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
file_kind,
|
||||
op_kind,
|
||||
)
|
||||
if matches is None:
|
||||
return None
|
||||
return int(matches[1])
|
||||
return int(val) if val is not None else val
|
||||
|
||||
pg = env.postgres.create_start("main", tenant_id=tenant_id)
|
||||
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.neon_fixtures import (
|
||||
NeonEnv,
|
||||
NeonEnvBuilder,
|
||||
PageserverApiException,
|
||||
PageserverHttpClient,
|
||||
@@ -12,6 +16,7 @@ from fixtures.neon_fixtures import (
|
||||
available_remote_storages,
|
||||
wait_for_last_record_lsn,
|
||||
wait_for_upload,
|
||||
wait_until,
|
||||
wait_until_tenant_state,
|
||||
)
|
||||
from fixtures.types import Lsn, TenantId, TimelineId
|
||||
@@ -84,6 +89,150 @@ def test_tenant_reattach(
|
||||
assert env.pageserver.log_contains(".*download.*failed, will retry.*")
|
||||
|
||||
|
||||
num_connections = 10
|
||||
num_rows = 100000
|
||||
updates_to_perform = 0
|
||||
|
||||
updates_started = 0
|
||||
updates_finished = 0
|
||||
|
||||
|
||||
# Run random UPDATEs on test table. On failure, try again.
|
||||
async def update_table(pg_conn: asyncpg.Connection):
|
||||
global updates_started, updates_finished, updates_to_perform
|
||||
|
||||
while updates_started < updates_to_perform or updates_to_perform == 0:
|
||||
updates_started += 1
|
||||
id = random.randrange(1, num_rows)
|
||||
|
||||
# Loop to retry until the UPDATE succeeds
|
||||
while True:
|
||||
try:
|
||||
await pg_conn.fetchrow(f"UPDATE t SET counter = counter + 1 WHERE id = {id}")
|
||||
updates_finished += 1
|
||||
if updates_finished % 1000 == 0:
|
||||
log.info(f"update {updates_finished} / {updates_to_perform}")
|
||||
break
|
||||
except asyncpg.PostgresError as e:
|
||||
# Received error from Postgres. Log it, sleep a little, and continue
|
||||
log.info(f"UPDATE error: {e}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
async def sleep_and_reattach(pageserver_http: PageserverHttpClient, tenant_id: TenantId):
|
||||
global updates_started, updates_finished, updates_to_perform
|
||||
|
||||
# Wait until we have performed some updates
|
||||
wait_until(20, 0.5, lambda: updates_finished > 500)
|
||||
|
||||
log.info("Detaching tenant")
|
||||
pageserver_http.tenant_detach(tenant_id)
|
||||
await asyncio.sleep(1)
|
||||
log.info("Re-attaching tenant")
|
||||
pageserver_http.tenant_attach(tenant_id)
|
||||
log.info("Re-attach finished")
|
||||
|
||||
# Continue with 5000 more updates
|
||||
updates_to_perform = updates_started + 5000
|
||||
|
||||
|
||||
# async guts of test_tenant_reattach_while_bysy test
|
||||
async def reattach_while_busy(
|
||||
env: NeonEnv, pg: Postgres, pageserver_http: PageserverHttpClient, tenant_id: TenantId
|
||||
):
|
||||
workers = []
|
||||
for worker_id in range(num_connections):
|
||||
pg_conn = await pg.connect_async()
|
||||
workers.append(asyncio.create_task(update_table(pg_conn)))
|
||||
|
||||
workers.append(asyncio.create_task(sleep_and_reattach(pageserver_http, tenant_id)))
|
||||
await asyncio.gather(*workers)
|
||||
|
||||
assert updates_finished == updates_to_perform
|
||||
|
||||
|
||||
# Detach and re-attach tenant, while compute is busy running queries.
|
||||
#
|
||||
# Some of the queries may fail, in the window that the tenant has been
|
||||
# detached but not yet re-attached. But Postgres itself should keep
|
||||
# running, and when we retry the queries, they should start working
|
||||
# after the attach has finished.
|
||||
|
||||
# FIXME:
|
||||
#
|
||||
# This is pretty unstable at the moment. I've seen it fail with a warning like this:
|
||||
#
|
||||
# AssertionError: assert not ['2023-01-05T13:09:40.708303Z WARN remote_upload{tenant=c3fc41f6cf29a7626b90316e3518cd4b timeline=7978246f85faa71ab03...1282b/000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001716699-0000000001736681"\n']
|
||||
#
|
||||
# (https://neon-github-public-dev.s3.amazonaws.com/reports/pr-3232/debug/3846817847/index.html#suites/f9eba3cfdb71aa6e2b54f6466222829b/470fc62b5db7d7d7/)
|
||||
# I believe that failure happened because there is a race condition
|
||||
# between detach and starting remote upload tasks:
|
||||
#
|
||||
# 1. detach_timeline calls task_mgr::shutdown_tasks(), sending shutdown
|
||||
# signal to all in-progress tasks associated with the tenant.
|
||||
# 2. Just after shutdown_tasks() has collected the list of tasks,
|
||||
# a new remote-upload task is spawned.
|
||||
#
|
||||
# See https://github.com/neondatabase/neon/issues/3273
|
||||
#
|
||||
#
|
||||
# I also saw this failure:
|
||||
#
|
||||
# test_runner/regress/test_tenant_detach.py:194: in test_tenant_reattach_while_busy
|
||||
# asyncio.run(reattach_while_busy(env, pg, pageserver_http, tenant_id))
|
||||
# /home/nonroot/.pyenv/versions/3.9.2/lib/python3.9/asyncio/runners.py:44: in run
|
||||
# return loop.run_until_complete(main)
|
||||
# /home/nonroot/.pyenv/versions/3.9.2/lib/python3.9/asyncio/base_events.py:642: in run_until_complete
|
||||
# return future.result()
|
||||
# test_runner/regress/test_tenant_detach.py:151: in reattach_while_busy
|
||||
# assert updates_finished == updates_to_perform
|
||||
# E assert 5010 == 10010
|
||||
# E +5010
|
||||
# E -10010
|
||||
#
|
||||
# I don't know what's causing that...
|
||||
@pytest.mark.skip(reason="fixme")
|
||||
@pytest.mark.parametrize("remote_storage_kind", available_remote_storages())
|
||||
def test_tenant_reattach_while_busy(
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
remote_storage_kind: RemoteStorageKind,
|
||||
):
|
||||
neon_env_builder.enable_remote_storage(
|
||||
remote_storage_kind=remote_storage_kind,
|
||||
test_name="test_tenant_reattach_while_busy",
|
||||
)
|
||||
env = neon_env_builder.init_start()
|
||||
|
||||
# Attempts to connect from compute to pageserver while the tenant is
|
||||
# temporarily detached produces these errors in the pageserver log.
|
||||
env.pageserver.allowed_errors.append(".*Tenant .* not found in the local state.*")
|
||||
env.pageserver.allowed_errors.append(
|
||||
".*Tenant .* will not become active\\. Current state: Stopping.*"
|
||||
)
|
||||
|
||||
pageserver_http = env.pageserver.http_client()
|
||||
|
||||
# create new nenant
|
||||
tenant_id, timeline_id = env.neon_cli.create_tenant(
|
||||
# Create layers aggressively
|
||||
conf={"checkpoint_distance": "100000"}
|
||||
)
|
||||
|
||||
pg = env.postgres.create_start("main", tenant_id=tenant_id)
|
||||
|
||||
cur = pg.connect().cursor()
|
||||
|
||||
cur.execute("CREATE TABLE t(id int primary key, counter int)")
|
||||
cur.execute(f"INSERT INTO t SELECT generate_series(1,{num_rows}), 0")
|
||||
|
||||
# Run the test
|
||||
asyncio.run(reattach_while_busy(env, pg, pageserver_http, tenant_id))
|
||||
|
||||
# Verify table contents
|
||||
assert query_scalar(cur, "SELECT count(*) FROM t") == num_rows
|
||||
assert query_scalar(cur, "SELECT sum(counter) FROM t") == updates_to_perform
|
||||
|
||||
|
||||
def test_tenant_detach_smoke(neon_env_builder: NeonEnvBuilder):
|
||||
env = neon_env_builder.init_start()
|
||||
pageserver_http = env.pageserver.http_client()
|
||||
|
||||
@@ -1105,7 +1105,6 @@ def test_delete_force(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
|
||||
env.pageserver.allowed_errors.extend(
|
||||
[
|
||||
".*Failed to process query for timeline .*: Timeline .* was not found in global map.*",
|
||||
".*end streaming to Some.*",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ scopeguard = { version = "1", features = ["use_std"] }
|
||||
serde = { version = "1", features = ["alloc", "derive", "serde_derive", "std"] }
|
||||
serde_json = { version = "1", features = ["raw_value", "std"] }
|
||||
socket2 = { version = "0.4", default-features = false, features = ["all"] }
|
||||
tokio = { version = "1", features = ["bytes", "fs", "io-std", "io-util", "libc", "macros", "memchr", "mio", "net", "num_cpus", "once_cell", "process", "rt", "rt-multi-thread", "signal-hook-registry", "socket2", "sync", "time", "tokio-macros"] }
|
||||
tokio = { version = "1", features = ["bytes", "fs", "io-std", "io-util", "libc", "macros", "memchr", "mio", "net", "num_cpus", "process", "rt", "rt-multi-thread", "signal-hook-registry", "socket2", "sync", "time", "tokio-macros"] }
|
||||
tokio-util = { version = "0.7", features = ["codec", "io", "io-util", "tracing"] }
|
||||
tower = { version = "0.4", features = ["__common", "balance", "buffer", "discover", "futures-core", "futures-util", "indexmap", "limit", "load", "log", "make", "pin-project", "pin-project-lite", "rand", "ready-cache", "retry", "slab", "timeout", "tokio", "tokio-util", "tracing", "util"] }
|
||||
tracing = { version = "0.1", features = ["attributes", "log", "std", "tracing-attributes"] }
|
||||
|
||||
Reference in New Issue
Block a user