mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-24 08:30:37 +00:00
Compare commits
4 Commits
jcsp/storc
...
installed_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d7fa802f4 | ||
|
|
811e5675a3 | ||
|
|
b9c6cd28a4 | ||
|
|
bd722f24ae |
@@ -43,8 +43,7 @@ runs:
|
||||
PR_NUMBER=$(jq --raw-output .pull_request.number "$GITHUB_EVENT_PATH" || true)
|
||||
if [ "${PR_NUMBER}" != "null" ]; then
|
||||
BRANCH_OR_PR=pr-${PR_NUMBER}
|
||||
elif [ "${GITHUB_REF_NAME}" = "main" ] || [ "${GITHUB_REF_NAME}" = "release" ] || \
|
||||
[ "${GITHUB_REF_NAME}" = "release-proxy" ] || [ "${GITHUB_REF_NAME}" = "release-compute" ]; then
|
||||
elif [ "${GITHUB_REF_NAME}" = "main" ] || [ "${GITHUB_REF_NAME}" = "release" ] || [ "${GITHUB_REF_NAME}" = "release-proxy" ]; then
|
||||
# Shortcut for special branches
|
||||
BRANCH_OR_PR=${GITHUB_REF_NAME}
|
||||
else
|
||||
|
||||
@@ -23,8 +23,7 @@ runs:
|
||||
PR_NUMBER=$(jq --raw-output .pull_request.number "$GITHUB_EVENT_PATH" || true)
|
||||
if [ "${PR_NUMBER}" != "null" ]; then
|
||||
BRANCH_OR_PR=pr-${PR_NUMBER}
|
||||
elif [ "${GITHUB_REF_NAME}" = "main" ] || [ "${GITHUB_REF_NAME}" = "release" ] || \
|
||||
[ "${GITHUB_REF_NAME}" = "release-proxy" ] || [ "${GITHUB_REF_NAME}" = "release-compute" ]; then
|
||||
elif [ "${GITHUB_REF_NAME}" = "main" ] || [ "${GITHUB_REF_NAME}" = "release" ] || [ "${GITHUB_REF_NAME}" = "release-proxy" ]; then
|
||||
# Shortcut for special branches
|
||||
BRANCH_OR_PR=${GITHUB_REF_NAME}
|
||||
else
|
||||
|
||||
2
.github/workflows/_create-release-pr.yml
vendored
2
.github/workflows/_create-release-pr.yml
vendored
@@ -21,7 +21,7 @@ defaults:
|
||||
shell: bash -euo pipefail {0}
|
||||
|
||||
jobs:
|
||||
create-release-branch:
|
||||
create-storage-release-branch:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
permissions:
|
||||
|
||||
23
.github/workflows/build_and_test.yml
vendored
23
.github/workflows/build_and_test.yml
vendored
@@ -6,7 +6,6 @@ on:
|
||||
- main
|
||||
- release
|
||||
- release-proxy
|
||||
- release-compute
|
||||
pull_request:
|
||||
|
||||
defaults:
|
||||
@@ -71,10 +70,8 @@ jobs:
|
||||
echo "tag=release-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT
|
||||
elif [[ "$GITHUB_REF_NAME" == "release-proxy" ]]; then
|
||||
echo "tag=release-proxy-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT
|
||||
elif [[ "$GITHUB_REF_NAME" == "release-compute" ]]; then
|
||||
echo "tag=release-compute-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release', 'release-proxy', 'release-compute'"
|
||||
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release'"
|
||||
echo "tag=$GITHUB_RUN_ID" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
shell: bash
|
||||
@@ -516,7 +513,7 @@ jobs:
|
||||
})
|
||||
|
||||
trigger-e2e-tests:
|
||||
if: ${{ !github.event.pull_request.draft || contains( github.event.pull_request.labels.*.name, 'run-e2e-tests-in-draft') || github.ref_name == 'main' || github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute' }}
|
||||
if: ${{ !github.event.pull_request.draft || contains( github.event.pull_request.labels.*.name, 'run-e2e-tests-in-draft') || github.ref_name == 'main' || github.ref_name == 'release' || github.ref_name == 'release-proxy' }}
|
||||
needs: [ check-permissions, promote-images, tag ]
|
||||
uses: ./.github/workflows/trigger-e2e-tests.yml
|
||||
secrets: inherit
|
||||
@@ -937,7 +934,7 @@ jobs:
|
||||
neondatabase/neon-test-extensions-v16:${{ needs.tag.outputs.build-tag }}
|
||||
|
||||
- name: Configure AWS-prod credentials
|
||||
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy' || github.ref_name == 'release-compute'
|
||||
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy'
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-region: eu-central-1
|
||||
@@ -946,12 +943,12 @@ jobs:
|
||||
|
||||
- name: Login to prod ECR
|
||||
uses: docker/login-action@v3
|
||||
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy' || github.ref_name == 'release-compute'
|
||||
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy'
|
||||
with:
|
||||
registry: 093970136003.dkr.ecr.eu-central-1.amazonaws.com
|
||||
|
||||
- name: Copy all images to prod ECR
|
||||
if: github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute'
|
||||
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy'
|
||||
run: |
|
||||
for image in neon compute-tools {vm-,}compute-node-{v14,v15,v16,v17}; do
|
||||
docker buildx imagetools create -t 093970136003.dkr.ecr.eu-central-1.amazonaws.com/${image}:${{ needs.tag.outputs.build-tag }} \
|
||||
@@ -971,7 +968,7 @@ jobs:
|
||||
tenant_id: ${{ vars.AZURE_TENANT_ID }}
|
||||
|
||||
push-to-acr-prod:
|
||||
if: github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute'
|
||||
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy'
|
||||
needs: [ tag, promote-images ]
|
||||
uses: ./.github/workflows/_push-to-acr.yml
|
||||
with:
|
||||
@@ -1059,7 +1056,7 @@ jobs:
|
||||
deploy:
|
||||
needs: [ check-permissions, promote-images, tag, build-and-test-locally, trigger-custom-extensions-build-and-wait, push-to-acr-dev, push-to-acr-prod ]
|
||||
# `!failure() && !cancelled()` is required because the workflow depends on the job that can be skipped: `push-to-acr-dev` and `push-to-acr-prod`
|
||||
if: (github.ref_name == 'main' || github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute') && !failure() && !cancelled()
|
||||
if: (github.ref_name == 'main' || github.ref_name == 'release' || github.ref_name == 'release-proxy') && !failure() && !cancelled()
|
||||
|
||||
runs-on: [ self-hosted, small ]
|
||||
container: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/ansible:latest
|
||||
@@ -1108,15 +1105,13 @@ jobs:
|
||||
-f deployProxyAuthBroker=true \
|
||||
-f branch=main \
|
||||
-f dockerTag=${{needs.tag.outputs.build-tag}}
|
||||
elif [[ "$GITHUB_REF_NAME" == "release-compute" ]]; then
|
||||
gh workflow --repo neondatabase/infra run deploy-compute-dev.yml --ref main -f dockerTag=${{needs.tag.outputs.build-tag}}
|
||||
else
|
||||
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main', 'release', 'release-proxy' or 'release-compute'"
|
||||
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Create git tag
|
||||
if: github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute'
|
||||
if: github.ref_name == 'release' || github.ref_name == 'release-proxy'
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
# Retry script for 5XX server errors: https://github.com/actions/github-script#retries
|
||||
|
||||
1
.github/workflows/ingest_benchmark.yml
vendored
1
.github/workflows/ingest_benchmark.yml
vendored
@@ -26,7 +26,6 @@ concurrency:
|
||||
jobs:
|
||||
ingest:
|
||||
strategy:
|
||||
fail-fast: false # allow other variants to continue even if one fails
|
||||
matrix:
|
||||
target_project: [new_empty_project, large_existing_project]
|
||||
permissions:
|
||||
|
||||
23
.github/workflows/release.yml
vendored
23
.github/workflows/release.yml
vendored
@@ -15,10 +15,6 @@ on:
|
||||
type: boolean
|
||||
description: 'Create Proxy release PR'
|
||||
required: false
|
||||
create-compute-release-branch:
|
||||
type: boolean
|
||||
description: 'Create Compute release PR'
|
||||
required: false
|
||||
|
||||
# No permission for GITHUB_TOKEN by default; the **minimal required** set of permissions should be granted in each job.
|
||||
permissions: {}
|
||||
@@ -29,20 +25,20 @@ defaults:
|
||||
|
||||
jobs:
|
||||
create-storage-release-branch:
|
||||
if: ${{ github.event.schedule == '0 6 * * MON' || inputs.create-storage-release-branch }}
|
||||
if: ${{ github.event.schedule == '0 6 * * MON' || format('{0}', inputs.create-storage-release-branch) == 'true' }}
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
uses: ./.github/workflows/_create-release-pr.yml
|
||||
with:
|
||||
component-name: 'Storage'
|
||||
component-name: 'Storage & Compute'
|
||||
release-branch: 'release'
|
||||
secrets:
|
||||
ci-access-token: ${{ secrets.CI_ACCESS_TOKEN }}
|
||||
|
||||
create-proxy-release-branch:
|
||||
if: ${{ github.event.schedule == '0 6 * * THU' || inputs.create-proxy-release-branch }}
|
||||
if: ${{ github.event.schedule == '0 6 * * THU' || format('{0}', inputs.create-proxy-release-branch) == 'true' }}
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -53,16 +49,3 @@ jobs:
|
||||
release-branch: 'release-proxy'
|
||||
secrets:
|
||||
ci-access-token: ${{ secrets.CI_ACCESS_TOKEN }}
|
||||
|
||||
create-compute-release-branch:
|
||||
if: inputs.create-compute-release-branch
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
uses: ./.github/workflows/_create-release-pr.yml
|
||||
with:
|
||||
component-name: 'Compute'
|
||||
release-branch: 'release-compute'
|
||||
secrets:
|
||||
ci-access-token: ${{ secrets.CI_ACCESS_TOKEN }}
|
||||
|
||||
2
.github/workflows/trigger-e2e-tests.yml
vendored
2
.github/workflows/trigger-e2e-tests.yml
vendored
@@ -51,8 +51,6 @@ jobs:
|
||||
echo "tag=release-$(git rev-list --count HEAD)" | tee -a $GITHUB_OUTPUT
|
||||
elif [[ "$GITHUB_REF_NAME" == "release-proxy" ]]; then
|
||||
echo "tag=release-proxy-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT
|
||||
elif [[ "$GITHUB_REF_NAME" == "release-compute" ]]; then
|
||||
echo "tag=release-compute-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release'"
|
||||
BUILD_AND_TEST_RUN_ID=$(gh run list -b $CURRENT_BRANCH -c $CURRENT_SHA -w 'Build and Test' -L 1 --json databaseId --jq '.[].databaseId')
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
/compute_tools/ @neondatabase/control-plane @neondatabase/compute
|
||||
/libs/pageserver_api/ @neondatabase/storage
|
||||
/libs/postgres_ffi/ @neondatabase/compute @neondatabase/storage
|
||||
/libs/proxy/ @neondatabase/proxy
|
||||
/libs/remote_storage/ @neondatabase/storage
|
||||
/libs/safekeeper_api/ @neondatabase/storage
|
||||
/libs/vm_monitor/ @neondatabase/autoscaling
|
||||
|
||||
142
Cargo.lock
generated
142
Cargo.lock
generated
@@ -84,16 +84,16 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.15"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526"
|
||||
checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"anstyle-parse",
|
||||
"anstyle-query",
|
||||
"anstyle-wincon",
|
||||
"colorchoice",
|
||||
"is_terminal_polyfill",
|
||||
"is-terminal",
|
||||
"utf8parse",
|
||||
]
|
||||
|
||||
@@ -123,12 +123,12 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-wincon"
|
||||
version = "3.0.4"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8"
|
||||
checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1031,9 +1031,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.9.0"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b"
|
||||
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
@@ -1167,33 +1167,35 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.22"
|
||||
version = "4.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69371e34337c4c984bbe322360c2547210bf632eb2814bbe78a6e87a2935bd2b"
|
||||
checksum = "93aae7a4192245f70fe75dd9157fc7b4a5bf53e88d30bd4396f7d8f9284d5acc"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
"clap_derive",
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_builder"
|
||||
version = "4.5.22"
|
||||
version = "4.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e24c1b4099818523236a8ca881d2b45db98dadfb4625cf6608c12069fcbbde1"
|
||||
checksum = "4f423e341edefb78c9caba2d9c7f7687d0e72e89df3ce3394554754393ac3990"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
"bitflags 1.3.2",
|
||||
"clap_lex",
|
||||
"strsim 0.11.1",
|
||||
"strsim",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_derive"
|
||||
version = "4.5.18"
|
||||
version = "4.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
|
||||
checksum = "191d9573962933b4027f932c600cd252ce27a8ad5979418fe78e43c07996f27b"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"heck 0.4.1",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.90",
|
||||
@@ -1201,9 +1203,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap_lex"
|
||||
version = "0.7.3"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7"
|
||||
checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b"
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
@@ -1612,7 +1614,7 @@ dependencies = [
|
||||
"ident_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim 0.10.0",
|
||||
"strsim",
|
||||
"syn 2.0.90",
|
||||
]
|
||||
|
||||
@@ -1810,7 +1812,7 @@ checksum = "0892a17df262a24294c382f0d5997571006e7a4348b4327557c4ff1cd4a8bccc"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"either",
|
||||
"heck",
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.90",
|
||||
@@ -1948,15 +1950,6 @@ dependencies = [
|
||||
"syn 2.0.90",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_filter"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab"
|
||||
dependencies = [
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.10.2"
|
||||
@@ -1970,18 +1963,6 @@ dependencies = [
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c012a26a7f605efc424dd53697843a72be7dc86ad2d01f7814337794a12231d"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
"env_filter",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equator"
|
||||
version = "0.2.2"
|
||||
@@ -2484,6 +2465,12 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.5.0"
|
||||
@@ -2901,12 +2888,6 @@ dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "is_terminal_polyfill"
|
||||
version = "1.70.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.10.5"
|
||||
@@ -3188,7 +3169,7 @@ version = "0.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.90",
|
||||
@@ -3389,16 +3370,6 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nu-ansi-term"
|
||||
version = "0.46.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
|
||||
dependencies = [
|
||||
"overload",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num"
|
||||
version = "0.4.1"
|
||||
@@ -3694,12 +3665,6 @@ version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a"
|
||||
|
||||
[[package]]
|
||||
name = "overload"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
|
||||
|
||||
[[package]]
|
||||
name = "p256"
|
||||
version = "0.11.1"
|
||||
@@ -4219,7 +4184,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "postgres-protocol"
|
||||
version = "0.6.4"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#511f998c00148ab7c847bd7e6cfd3a906d0e7473"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796"
|
||||
dependencies = [
|
||||
"base64 0.20.0",
|
||||
"byteorder",
|
||||
@@ -4232,6 +4197,7 @@ dependencies = [
|
||||
"rand 0.8.5",
|
||||
"sha2",
|
||||
"stringprep",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4243,6 +4209,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"hmac",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"rand 0.8.5",
|
||||
"sha2",
|
||||
@@ -4253,7 +4220,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "postgres-types"
|
||||
version = "0.2.4"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#511f998c00148ab7c847bd7e6cfd3a906d0e7473"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
@@ -4309,7 +4276,7 @@ dependencies = [
|
||||
"bindgen",
|
||||
"bytes",
|
||||
"crc32c",
|
||||
"env_logger 0.10.2",
|
||||
"env_logger",
|
||||
"log",
|
||||
"memoffset 0.9.0",
|
||||
"once_cell",
|
||||
@@ -4492,7 +4459,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"heck",
|
||||
"heck 0.5.0",
|
||||
"itertools 0.12.1",
|
||||
"log",
|
||||
"multimap",
|
||||
@@ -4577,7 +4544,7 @@ dependencies = [
|
||||
"consumption_metrics",
|
||||
"dashmap",
|
||||
"ecdsa 0.16.9",
|
||||
"env_logger 0.10.2",
|
||||
"env_logger",
|
||||
"fallible-iterator",
|
||||
"flate2",
|
||||
"framed-websockets",
|
||||
@@ -4645,7 +4612,6 @@ dependencies = [
|
||||
"tikv-jemalloc-ctl",
|
||||
"tikv-jemallocator",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres2",
|
||||
"tokio-rustls 0.26.0",
|
||||
"tokio-tungstenite",
|
||||
@@ -6100,7 +6066,6 @@ dependencies = [
|
||||
"serde_json",
|
||||
"strum",
|
||||
"strum_macros",
|
||||
"test-log",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
@@ -6201,12 +6166,6 @@ version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||
|
||||
[[package]]
|
||||
name = "strum"
|
||||
version = "0.26.3"
|
||||
@@ -6219,7 +6178,7 @@ version = "0.26.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
@@ -6387,28 +6346,6 @@ dependencies = [
|
||||
"syn 2.0.90",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "test-log"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3dffced63c2b5c7be278154d76b479f9f9920ed34e7574201407f0b14e2bbb93"
|
||||
dependencies = [
|
||||
"env_logger 0.11.2",
|
||||
"test-log-macros",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "test-log-macros"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5999e24eaa32083191ba4e425deb75cdf25efefabe5aaccb7446dd0d4122a3f5"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.90",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.69"
|
||||
@@ -6606,7 +6543,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "tokio-postgres"
|
||||
version = "0.7.7"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#511f998c00148ab7c847bd7e6cfd3a906d0e7473"
|
||||
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"byteorder",
|
||||
@@ -6960,7 +6897,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008"
|
||||
dependencies = [
|
||||
"matchers",
|
||||
"nu-ansi-term",
|
||||
"once_cell",
|
||||
"regex",
|
||||
"serde",
|
||||
@@ -7269,7 +7205,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"camino-tempfile",
|
||||
"clap",
|
||||
"env_logger 0.10.2",
|
||||
"env_logger",
|
||||
"log",
|
||||
"postgres",
|
||||
"postgres_ffi",
|
||||
|
||||
@@ -74,7 +74,7 @@ bindgen = "0.70"
|
||||
bit_field = "0.10.2"
|
||||
bstr = "1.0"
|
||||
byteorder = "1.4"
|
||||
bytes = "1.9"
|
||||
bytes = "1.0"
|
||||
camino = "1.1.6"
|
||||
cfg-if = "1.0.0"
|
||||
chrono = { version = "0.4", default-features = false, features = ["clock"] }
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
[databases]
|
||||
;; pgbouncer propagates application_name (if it's specified) to the server, but some
|
||||
;; clients don't set it. We set default application_name=pgbouncer to make it
|
||||
;; easier to identify pgbouncer connections in Postgres. If client sets
|
||||
;; application_name, it will be used instead.
|
||||
*=host=localhost port=5432 auth_user=cloud_admin application_name=pgbouncer
|
||||
*=host=localhost port=5432 auth_user=cloud_admin
|
||||
[pgbouncer]
|
||||
listen_port=6432
|
||||
listen_addr=0.0.0.0
|
||||
|
||||
@@ -310,6 +310,41 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
}
|
||||
}
|
||||
|
||||
// handle HEAD method of /installed_extensions route
|
||||
// This is the signal from postgres that extension DDL occured and we need to update the metric.
|
||||
//
|
||||
// Don't wait for the result, because the caller doesn't need it
|
||||
// just spawn a task and the metric will be updated eventually.
|
||||
//
|
||||
// In theory there could be multiple HEAD requests in a row, so we should protect
|
||||
// from spawning multiple tasks, but in practice it's not a problem.
|
||||
// TODO: add some mutex or quere?
|
||||
// In practice, extensions are not installed very often, so it's not a problem
|
||||
(&Method::HEAD, route) if route.starts_with("/installed_extensions") => {
|
||||
info!("serving /installed_extensions HEAD request");
|
||||
let status = compute.get_status();
|
||||
if status != ComputeStatus::Running {
|
||||
let msg = format!(
|
||||
"invalid compute status for extensions request: {:?}",
|
||||
status
|
||||
);
|
||||
error!(msg);
|
||||
let mut resp = Response::new(Body::from(msg));
|
||||
*resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return resp;
|
||||
}
|
||||
|
||||
let conf = compute.get_conn_conf(None);
|
||||
task::spawn(async move {
|
||||
let _ = task::spawn_blocking(move || {
|
||||
installed_extensions::get_installed_extensions(conf)
|
||||
})
|
||||
.await;
|
||||
});
|
||||
|
||||
Response::new(Body::from("OK"))
|
||||
}
|
||||
|
||||
// download extension files from remote extension storage on demand
|
||||
(&Method::POST, route) if route.starts_with("/extension_server/") => {
|
||||
info!("serving {:?} POST request", route);
|
||||
|
||||
@@ -82,6 +82,21 @@ paths:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/InstalledExtensions"
|
||||
head:
|
||||
tags:
|
||||
- Extension
|
||||
summary: Report extension DDL to trigger metric recollection.
|
||||
description: ""
|
||||
operationId: ExtensionDDL
|
||||
responses:
|
||||
200:
|
||||
description: Result
|
||||
500:
|
||||
description: Internal error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/GenericError"
|
||||
/info:
|
||||
get:
|
||||
tags:
|
||||
@@ -537,12 +552,14 @@ components:
|
||||
properties:
|
||||
extname:
|
||||
type: string
|
||||
versions:
|
||||
type: array
|
||||
version:
|
||||
type: string
|
||||
items:
|
||||
type: string
|
||||
n_databases:
|
||||
type: integer
|
||||
owned_by_superuser:
|
||||
type: integer
|
||||
|
||||
SetRoleGrantsRequest:
|
||||
type: object
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use compute_api::responses::{InstalledExtension, InstalledExtensions};
|
||||
use metrics::proto::MetricFamily;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use anyhow::Result;
|
||||
use postgres::{Client, NoTls};
|
||||
@@ -38,65 +37,94 @@ fn list_dbs(client: &mut Client) -> Result<Vec<String>> {
|
||||
/// Connect to every database (see list_dbs above) and get the list of installed extensions.
|
||||
///
|
||||
/// Same extension can be installed in multiple databases with different versions,
|
||||
/// we only keep the highest and lowest version across all databases.
|
||||
/// so we report a separate metric (number of databases where it is installed)
|
||||
/// for each extension version.
|
||||
pub fn get_installed_extensions(mut conf: postgres::config::Config) -> Result<InstalledExtensions> {
|
||||
conf.application_name("compute_ctl:get_installed_extensions");
|
||||
let mut client = conf.connect(NoTls)?;
|
||||
|
||||
let databases: Vec<String> = list_dbs(&mut client)?;
|
||||
|
||||
let mut extensions_map: HashMap<String, InstalledExtension> = HashMap::new();
|
||||
let mut extensions_map: HashMap<(String, String, String), InstalledExtension> = HashMap::new();
|
||||
for db in databases.iter() {
|
||||
conf.dbname(db);
|
||||
let mut db_client = conf.connect(NoTls)?;
|
||||
let extensions: Vec<(String, String)> = db_client
|
||||
let extensions: Vec<(String, String, i32)> = db_client
|
||||
.query(
|
||||
"SELECT extname, extversion FROM pg_catalog.pg_extension;",
|
||||
"SELECT extname, extversion, extowner::integer FROM pg_catalog.pg_extension",
|
||||
&[],
|
||||
)?
|
||||
.iter()
|
||||
.map(|row| (row.get("extname"), row.get("extversion")))
|
||||
.map(|row| {
|
||||
(
|
||||
row.get("extname"),
|
||||
row.get("extversion"),
|
||||
row.get("extowner"),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for (extname, v) in extensions.iter() {
|
||||
for (extname, v, extowner) in extensions.iter() {
|
||||
let version = v.to_string();
|
||||
|
||||
// increment the number of databases where the version of extension is installed
|
||||
INSTALLED_EXTENSIONS
|
||||
.with_label_values(&[extname, &version])
|
||||
.inc();
|
||||
// check if the extension is owned by superuser
|
||||
// 10 is the oid of superuser
|
||||
let owned_by_superuser = if *extowner == 10 { "1" } else { "0" };
|
||||
|
||||
extensions_map
|
||||
.entry(extname.to_string())
|
||||
.entry((
|
||||
extname.to_string(),
|
||||
version.clone(),
|
||||
owned_by_superuser.to_string(),
|
||||
))
|
||||
.and_modify(|e| {
|
||||
e.versions.insert(version.clone());
|
||||
// count the number of databases where the extension is installed
|
||||
e.n_databases += 1;
|
||||
})
|
||||
.or_insert(InstalledExtension {
|
||||
extname: extname.to_string(),
|
||||
versions: HashSet::from([version.clone()]),
|
||||
version: version.clone(),
|
||||
n_databases: 1,
|
||||
owned_by_superuser: owned_by_superuser.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let res = InstalledExtensions {
|
||||
extensions: extensions_map.into_values().collect(),
|
||||
};
|
||||
// reset the metric to handle dropped extensions and extension version changes -
|
||||
// we need to remove them from the metric.
|
||||
// It creates a race condition - if collector is called before we set the new values
|
||||
// we will have a gap in the metrics.
|
||||
//
|
||||
// TODO: Add a mutex to lock the metric update and collection
|
||||
// so that the collector doesn't see intermediate state.
|
||||
// Or is it ok for this metric to be eventually consistent?
|
||||
INSTALLED_EXTENSIONS.reset();
|
||||
|
||||
Ok(res)
|
||||
for (key, ext) in extensions_map.iter() {
|
||||
let (extname, version, owned_by_superuser) = key;
|
||||
let n_databases = ext.n_databases as u64;
|
||||
|
||||
INSTALLED_EXTENSIONS
|
||||
.with_label_values(&[extname, version, owned_by_superuser])
|
||||
.set(n_databases);
|
||||
}
|
||||
|
||||
tracing::debug!("Installed extensions: {:?}", extensions_map);
|
||||
|
||||
Ok(InstalledExtensions {
|
||||
extensions: extensions_map.into_values().collect(),
|
||||
})
|
||||
}
|
||||
|
||||
static INSTALLED_EXTENSIONS: Lazy<UIntGaugeVec> = Lazy::new(|| {
|
||||
register_uint_gauge_vec!(
|
||||
"compute_installed_extensions",
|
||||
"Number of databases where the version of extension is installed",
|
||||
&["extension_name", "version"]
|
||||
&["extension_name", "version", "owned_by_superuser"]
|
||||
)
|
||||
.expect("failed to define a metric")
|
||||
});
|
||||
|
||||
pub fn collect() -> Vec<MetricFamily> {
|
||||
// TODO Add a mutex here to ensure that the metric is not updated while we are collecting it
|
||||
INSTALLED_EXTENSIONS.collect()
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
//! ```text
|
||||
//! .neon/safekeepers/<safekeeper id>
|
||||
//! ```
|
||||
use std::error::Error as _;
|
||||
use std::future::Future;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
@@ -27,7 +26,7 @@ use crate::{
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum SafekeeperHttpError {
|
||||
#[error("request error: {0}{}", .0.source().map(|e| format!(": {e}")).unwrap_or_default())]
|
||||
#[error("Reqwest error: {0}")]
|
||||
Transport(#[from] reqwest::Error),
|
||||
|
||||
#[error("Error: {0}")]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
//! Structs representing the JSON formats used in the compute_ctl's HTTP API.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::fmt::Display;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
@@ -163,8 +162,9 @@ pub enum ControlPlaneComputeStatus {
|
||||
#[derive(Clone, Debug, Default, Serialize)]
|
||||
pub struct InstalledExtension {
|
||||
pub extname: String,
|
||||
pub versions: HashSet<String>,
|
||||
pub version: String,
|
||||
pub n_databases: u32, // Number of databases using this extension
|
||||
pub owned_by_superuser: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize)]
|
||||
|
||||
@@ -442,14 +442,7 @@ impl Default for ConfigToml {
|
||||
tenant_config: TenantConfigToml::default(),
|
||||
no_sync: None,
|
||||
wal_receiver_protocol: DEFAULT_WAL_RECEIVER_PROTOCOL,
|
||||
page_service_pipelining: if !cfg!(test) {
|
||||
PageServicePipeliningConfig::Serial
|
||||
} else {
|
||||
PageServicePipeliningConfig::Pipelined(PageServicePipeliningConfigPipelined {
|
||||
max_batch_size: NonZeroUsize::new(32).unwrap(),
|
||||
execution: PageServiceProtocolPipelinedExecutionStrategy::ConcurrentFutures,
|
||||
})
|
||||
},
|
||||
page_service_pipelining: PageServicePipeliningConfig::Serial,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ pub struct TenantCreateResponse {
|
||||
pub shards: Vec<TenantCreateResponseShard>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct NodeRegisterRequest {
|
||||
pub node_id: NodeId,
|
||||
|
||||
@@ -75,7 +75,7 @@ pub struct TenantPolicyRequest {
|
||||
pub scheduling: Option<ShardSchedulingPolicy>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Debug)]
|
||||
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
|
||||
pub struct AvailabilityZone(pub String);
|
||||
|
||||
impl Display for AvailabilityZone {
|
||||
@@ -325,16 +325,6 @@ pub enum PlacementPolicy {
|
||||
Detached,
|
||||
}
|
||||
|
||||
impl PlacementPolicy {
|
||||
pub fn want_secondaries(&self) -> usize {
|
||||
match self {
|
||||
PlacementPolicy::Attached(secondary_count) => *secondary_count,
|
||||
PlacementPolicy::Secondary => 1,
|
||||
PlacementPolicy::Detached => 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct TenantShardMigrateResponse {}
|
||||
|
||||
|
||||
@@ -770,11 +770,6 @@ impl Key {
|
||||
&& self.field6 == 1
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn is_aux_file_key(&self) -> bool {
|
||||
self.field1 == AUX_KEY_PREFIX
|
||||
}
|
||||
|
||||
/// Guaranteed to return `Ok()` if [`Self::is_rel_block_key`] returns `true` for `key`.
|
||||
#[inline(always)]
|
||||
pub fn to_rel_block(self) -> anyhow::Result<(RelTag, BlockNumber)> {
|
||||
|
||||
@@ -501,9 +501,7 @@ pub struct EvictionPolicyLayerAccessThreshold {
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub struct ThrottleConfig {
|
||||
/// See [`ThrottleConfigTaskKinds`] for why we do the serde `rename`.
|
||||
#[serde(rename = "task_kinds")]
|
||||
pub enabled: ThrottleConfigTaskKinds,
|
||||
pub task_kinds: Vec<String>, // TaskKind
|
||||
pub initial: u32,
|
||||
#[serde(with = "humantime_serde")]
|
||||
pub refill_interval: Duration,
|
||||
@@ -511,38 +509,10 @@ pub struct ThrottleConfig {
|
||||
pub max: u32,
|
||||
}
|
||||
|
||||
/// Before <https://github.com/neondatabase/neon/pull/9962>
|
||||
/// the throttle was a per `Timeline::get`/`Timeline::get_vectored` call.
|
||||
/// The `task_kinds` field controlled which Pageserver "Task Kind"s
|
||||
/// were subject to the throttle.
|
||||
///
|
||||
/// After that PR, the throttle is applied at pagestream request level
|
||||
/// and the `task_kinds` field does not apply since the only task kind
|
||||
/// that us subject to the throttle is that of the page service.
|
||||
///
|
||||
/// However, we don't want to make a breaking config change right now
|
||||
/// because it means we have to migrate all the tenant configs.
|
||||
/// This will be done in a future PR.
|
||||
///
|
||||
/// In the meantime, we use emptiness / non-emptsiness of the `task_kinds`
|
||||
/// field to determine if the throttle is enabled or not.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(transparent)]
|
||||
pub struct ThrottleConfigTaskKinds(Vec<String>);
|
||||
|
||||
impl ThrottleConfigTaskKinds {
|
||||
pub fn disabled() -> Self {
|
||||
Self(vec![])
|
||||
}
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
!self.0.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl ThrottleConfig {
|
||||
pub fn disabled() -> Self {
|
||||
Self {
|
||||
enabled: ThrottleConfigTaskKinds::disabled(),
|
||||
task_kinds: vec![], // effectively disables the throttle
|
||||
// other values don't matter with emtpy `task_kinds`.
|
||||
initial: 0,
|
||||
refill_interval: Duration::from_millis(1),
|
||||
@@ -556,30 +526,6 @@ impl ThrottleConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod throttle_config_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_disabled_is_disabled() {
|
||||
let config = ThrottleConfig::disabled();
|
||||
assert!(!config.enabled.is_enabled());
|
||||
}
|
||||
#[test]
|
||||
fn test_enabled_backwards_compat() {
|
||||
let input = serde_json::json!({
|
||||
"task_kinds": ["PageRequestHandler"],
|
||||
"initial": 40000,
|
||||
"refill_interval": "50ms",
|
||||
"refill_amount": 1000,
|
||||
"max": 40000,
|
||||
"fair": true
|
||||
});
|
||||
let config: ThrottleConfig = serde_json::from_value(input).unwrap();
|
||||
assert!(config.enabled.is_enabled());
|
||||
}
|
||||
}
|
||||
|
||||
/// A flattened analog of a `pagesever::tenant::LocationMode`, which
|
||||
/// lists out all possible states (and the virtual "Detached" state)
|
||||
/// in a flat form rather than using rust-style enums.
|
||||
|
||||
@@ -170,37 +170,19 @@ impl ShardIdentity {
|
||||
}
|
||||
}
|
||||
|
||||
/// Return true if the key should be stored on all shards, not just one.
|
||||
fn is_key_global(&self, key: &Key) -> bool {
|
||||
if key.is_slru_block_key() || key.is_slru_segment_size_key() || key.is_aux_file_key() {
|
||||
// Special keys that are only stored on shard 0
|
||||
false
|
||||
} else if key.is_rel_block_key() {
|
||||
// Ordinary relation blocks are distributed across shards
|
||||
false
|
||||
} else if key.is_rel_size_key() {
|
||||
// All shards maintain rel size keys (although only shard 0 is responsible for
|
||||
// keeping it strictly accurate, other shards just reflect the highest block they've ingested)
|
||||
true
|
||||
} else {
|
||||
// For everything else, we assume it must be kept everywhere, because ingest code
|
||||
// might assume this -- this covers functionality where the ingest code has
|
||||
// not (yet) been made fully shard aware.
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Return true if the key should be discarded if found in this shard's
|
||||
/// data store, e.g. during compaction after a split.
|
||||
///
|
||||
/// Shards _may_ drop keys which return false here, but are not obliged to.
|
||||
pub fn is_key_disposable(&self, key: &Key) -> bool {
|
||||
if self.count < ShardCount(2) {
|
||||
// Fast path: unsharded tenant doesn't dispose of anything
|
||||
return false;
|
||||
}
|
||||
|
||||
if self.is_key_global(key) {
|
||||
if key_is_shard0(key) {
|
||||
// Q: Why can't we dispose of shard0 content if we're not shard 0?
|
||||
// A1: because the WAL ingestion logic currently ingests some shard 0
|
||||
// content on all shards, even though it's only read on shard 0. If we
|
||||
// dropped it, then subsequent WAL ingest to these keys would encounter
|
||||
// an error.
|
||||
// A2: because key_is_shard0 also covers relation size keys, which are written
|
||||
// on all shards even though they're only maintained accurately on shard 0.
|
||||
false
|
||||
} else {
|
||||
!self.is_key_local(key)
|
||||
|
||||
@@ -100,7 +100,7 @@ impl StartupMessageParamsBuilder {
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct StartupMessageParams {
|
||||
pub params: Bytes,
|
||||
params: Bytes,
|
||||
}
|
||||
|
||||
impl StartupMessageParams {
|
||||
|
||||
@@ -10,6 +10,7 @@ byteorder.workspace = true
|
||||
bytes.workspace = true
|
||||
fallible-iterator.workspace = true
|
||||
hmac.workspace = true
|
||||
md-5 = "0.10"
|
||||
memchr = "2.0"
|
||||
rand.workspace = true
|
||||
sha2.workspace = true
|
||||
|
||||
@@ -1,2 +1,37 @@
|
||||
//! Authentication protocol support.
|
||||
use md5::{Digest, Md5};
|
||||
|
||||
pub mod sasl;
|
||||
|
||||
/// Hashes authentication information in a way suitable for use in response
|
||||
/// to an `AuthenticationMd5Password` message.
|
||||
///
|
||||
/// The resulting string should be sent back to the database in a
|
||||
/// `PasswordMessage` message.
|
||||
#[inline]
|
||||
pub fn md5_hash(username: &[u8], password: &[u8], salt: [u8; 4]) -> String {
|
||||
let mut md5 = Md5::new();
|
||||
md5.update(password);
|
||||
md5.update(username);
|
||||
let output = md5.finalize_reset();
|
||||
md5.update(format!("{:x}", output));
|
||||
md5.update(salt);
|
||||
format!("md5{:x}", md5.finalize())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn md5() {
|
||||
let username = b"md5_user";
|
||||
let password = b"password";
|
||||
let salt = [0x2a, 0x3d, 0x8f, 0xe0];
|
||||
|
||||
assert_eq!(
|
||||
md5_hash(username, password, salt),
|
||||
"md562af4dd09bbb41884907a838a3233294"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,7 +117,7 @@ enum Credentials<const N: usize> {
|
||||
/// A regular password as a vector of bytes.
|
||||
Password(Vec<u8>),
|
||||
/// A precomputed pair of keys.
|
||||
Keys(ScramKeys<N>),
|
||||
Keys(Box<ScramKeys<N>>),
|
||||
}
|
||||
|
||||
enum State {
|
||||
@@ -176,7 +176,7 @@ impl ScramSha256 {
|
||||
|
||||
/// Constructs a new instance which will use the provided key pair for authentication.
|
||||
pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
|
||||
let password = Credentials::Keys(keys);
|
||||
let password = Credentials::Keys(keys.into());
|
||||
ScramSha256::new_inner(password, channel_binding, nonce())
|
||||
}
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ pub enum Message {
|
||||
AuthenticationCleartextPassword,
|
||||
AuthenticationGss,
|
||||
AuthenticationKerberosV5,
|
||||
AuthenticationMd5Password,
|
||||
AuthenticationMd5Password(AuthenticationMd5PasswordBody),
|
||||
AuthenticationOk,
|
||||
AuthenticationScmCredential,
|
||||
AuthenticationSspi,
|
||||
@@ -191,7 +191,11 @@ impl Message {
|
||||
0 => Message::AuthenticationOk,
|
||||
2 => Message::AuthenticationKerberosV5,
|
||||
3 => Message::AuthenticationCleartextPassword,
|
||||
5 => Message::AuthenticationMd5Password,
|
||||
5 => {
|
||||
let mut salt = [0; 4];
|
||||
buf.read_exact(&mut salt)?;
|
||||
Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt })
|
||||
}
|
||||
6 => Message::AuthenticationScmCredential,
|
||||
7 => Message::AuthenticationGss,
|
||||
8 => Message::AuthenticationGssContinue,
|
||||
|
||||
@@ -255,34 +255,22 @@ pub fn ssl_request(buf: &mut BytesMut) {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
|
||||
pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()>
|
||||
where
|
||||
I: IntoIterator<Item = (&'a str, &'a str)>,
|
||||
{
|
||||
write_body(buf, |buf| {
|
||||
// postgres protocol version 3.0(196608) in bigger-endian
|
||||
buf.put_i32(0x00_03_00_00);
|
||||
buf.put_slice(¶meters.params);
|
||||
for (key, value) in parameters {
|
||||
write_cstr(key.as_bytes(), buf)?;
|
||||
write_cstr(value.as_bytes(), buf)?;
|
||||
}
|
||||
buf.put_u8(0);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct StartupMessageParams {
|
||||
pub params: BytesMut,
|
||||
}
|
||||
|
||||
impl StartupMessageParams {
|
||||
/// Set parameter's value by its name.
|
||||
pub fn insert(&mut self, name: &str, value: &str) {
|
||||
if name.contains('\0') || value.contains('\0') {
|
||||
panic!("startup parameter name or value contained a null")
|
||||
}
|
||||
self.params.put_slice(name.as_bytes());
|
||||
self.params.put_u8(0);
|
||||
self.params.put_slice(value.as_bytes());
|
||||
self.params.put_u8(0);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn sync(buf: &mut BytesMut) {
|
||||
buf.put_u8(b'S');
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
use crate::authentication::sasl;
|
||||
use hmac::{Hmac, Mac};
|
||||
use md5::Md5;
|
||||
use rand::RngCore;
|
||||
use sha2::digest::FixedOutput;
|
||||
use sha2::{Digest, Sha256};
|
||||
@@ -87,3 +88,20 @@ pub(crate) async fn scram_sha_256_salt(
|
||||
base64::encode(server_key)
|
||||
)
|
||||
}
|
||||
|
||||
/// **Not recommended, as MD5 is not considered to be secure.**
|
||||
///
|
||||
/// Hash password using MD5 with the username as the salt.
|
||||
///
|
||||
/// The client may assume the returned string doesn't contain any
|
||||
/// special characters that would require escaping.
|
||||
pub fn md5(password: &[u8], username: &str) -> String {
|
||||
// salt password with username
|
||||
let mut salted_password = Vec::from(password);
|
||||
salted_password.extend_from_slice(username.as_bytes());
|
||||
|
||||
let mut hash = Md5::new();
|
||||
hash.update(&salted_password);
|
||||
let digest = hash.finalize();
|
||||
format!("md5{:x}", digest)
|
||||
}
|
||||
|
||||
@@ -9,3 +9,11 @@ async fn test_encrypt_scram_sha_256() {
|
||||
"SCRAM-SHA-256$4096:AQIDBAUGBwgJCgsMDQ4PEA==$8rrDg00OqaiWXJ7p+sCgHEIaBSHY89ZJl3mfIsf32oY=:05L1f+yZbiN8O0AnO40Og85NNRhvzTS57naKRWCcsIA="
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encrypt_md5() {
|
||||
assert_eq!(
|
||||
password::md5(b"secret", "foo"),
|
||||
"md54ab2c5d00339c4b2a4e921d2dc4edec7"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresCodec;
|
||||
pub struct PostgresCodec {
|
||||
pub max_message_size: Option<usize>,
|
||||
}
|
||||
|
||||
impl Encoder<FrontendMessage> for PostgresCodec {
|
||||
type Error = io::Error;
|
||||
@@ -64,6 +66,15 @@ impl Decoder for PostgresCodec {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(max) = self.max_message_size {
|
||||
if len > max {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"message too large",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
match header.tag() {
|
||||
backend::NOTICE_RESPONSE_TAG
|
||||
| backend::NOTIFICATION_RESPONSE_TAG
|
||||
|
||||
@@ -6,15 +6,26 @@ use crate::connect_raw::RawConnection;
|
||||
use crate::tls::MakeTlsConnect;
|
||||
use crate::tls::TlsConnect;
|
||||
use crate::{Client, Connection, Error};
|
||||
use postgres_protocol2::message::frontend::StartupMessageParams;
|
||||
use std::fmt;
|
||||
use std::borrow::Cow;
|
||||
use std::str;
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use std::{error, fmt, iter, mem};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub use postgres_protocol2::authentication::sasl::ScramKeys;
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// Properties required of a session.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
#[non_exhaustive]
|
||||
pub enum TargetSessionAttrs {
|
||||
/// No special properties are required.
|
||||
Any,
|
||||
/// The session must allow writes.
|
||||
ReadWrite,
|
||||
}
|
||||
|
||||
/// TLS configuration.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
#[non_exhaustive]
|
||||
@@ -64,37 +75,119 @@ pub enum AuthKeys {
|
||||
}
|
||||
|
||||
/// Connection configuration.
|
||||
///
|
||||
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
|
||||
///
|
||||
/// # Key-Value
|
||||
///
|
||||
/// This format consists of space-separated key-value pairs. Values which are either the empty string or contain
|
||||
/// whitespace should be wrapped in `'`. `'` and `\` characters should be backslash-escaped.
|
||||
///
|
||||
/// ## Keys
|
||||
///
|
||||
/// * `user` - The username to authenticate with. Required.
|
||||
/// * `password` - The password to authenticate with.
|
||||
/// * `dbname` - The name of the database to connect to. Defaults to the username.
|
||||
/// * `options` - Command line options used to configure the server.
|
||||
/// * `application_name` - Sets the `application_name` parameter on the server.
|
||||
/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used
|
||||
/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`.
|
||||
/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the
|
||||
/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
|
||||
/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
|
||||
/// with the `connect` method.
|
||||
/// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be
|
||||
/// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if
|
||||
/// omitted or the empty string.
|
||||
/// * `connect_timeout` - The time limit in seconds applied to each socket-level connection attempt. Note that hostnames
|
||||
/// can resolve to multiple IP addresses, and this limit is applied to each address. Defaults to no timeout.
|
||||
/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that
|
||||
/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server
|
||||
/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`.
|
||||
/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel
|
||||
/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise.
|
||||
/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`.
|
||||
///
|
||||
/// ## Examples
|
||||
///
|
||||
/// ```not_rust
|
||||
/// host=localhost user=postgres connect_timeout=10 keepalives=0
|
||||
/// ```
|
||||
///
|
||||
/// ```not_rust
|
||||
/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces'
|
||||
/// ```
|
||||
///
|
||||
/// ```not_rust
|
||||
/// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write
|
||||
/// ```
|
||||
///
|
||||
/// # Url
|
||||
///
|
||||
/// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional,
|
||||
/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple
|
||||
/// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded,
|
||||
/// as the path component of the URL specifies the database name.
|
||||
///
|
||||
/// ## Examples
|
||||
///
|
||||
/// ```not_rust
|
||||
/// postgresql://user@localhost
|
||||
/// ```
|
||||
///
|
||||
/// ```not_rust
|
||||
/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10
|
||||
/// ```
|
||||
///
|
||||
/// ```not_rust
|
||||
/// postgresql://user@host1:1234,host2,host3:5678?target_session_attrs=read-write
|
||||
/// ```
|
||||
///
|
||||
/// ```not_rust
|
||||
/// postgresql:///mydb?user=user&host=/var/lib/postgresql
|
||||
/// ```
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct Config {
|
||||
pub(crate) host: Host,
|
||||
pub(crate) port: u16,
|
||||
|
||||
pub(crate) user: Option<String>,
|
||||
pub(crate) password: Option<Vec<u8>>,
|
||||
pub(crate) auth_keys: Option<Box<AuthKeys>>,
|
||||
pub(crate) dbname: Option<String>,
|
||||
pub(crate) options: Option<String>,
|
||||
pub(crate) application_name: Option<String>,
|
||||
pub(crate) ssl_mode: SslMode,
|
||||
pub(crate) host: Vec<Host>,
|
||||
pub(crate) port: Vec<u16>,
|
||||
pub(crate) connect_timeout: Option<Duration>,
|
||||
pub(crate) target_session_attrs: TargetSessionAttrs,
|
||||
pub(crate) channel_binding: ChannelBinding,
|
||||
pub(crate) server_params: StartupMessageParams,
|
||||
pub(crate) replication_mode: Option<ReplicationMode>,
|
||||
pub(crate) max_backend_message_size: Option<usize>,
|
||||
}
|
||||
|
||||
database: bool,
|
||||
username: bool,
|
||||
impl Default for Config {
|
||||
fn default() -> Config {
|
||||
Config::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Creates a new configuration.
|
||||
pub fn new(host: String, port: u16) -> Config {
|
||||
pub fn new() -> Config {
|
||||
Config {
|
||||
host: Host::Tcp(host),
|
||||
port,
|
||||
user: None,
|
||||
password: None,
|
||||
auth_keys: None,
|
||||
dbname: None,
|
||||
options: None,
|
||||
application_name: None,
|
||||
ssl_mode: SslMode::Prefer,
|
||||
host: vec![],
|
||||
port: vec![],
|
||||
connect_timeout: None,
|
||||
target_session_attrs: TargetSessionAttrs::Any,
|
||||
channel_binding: ChannelBinding::Prefer,
|
||||
server_params: StartupMessageParams::default(),
|
||||
|
||||
database: false,
|
||||
username: false,
|
||||
replication_mode: None,
|
||||
max_backend_message_size: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,13 +195,14 @@ impl Config {
|
||||
///
|
||||
/// Required.
|
||||
pub fn user(&mut self, user: &str) -> &mut Config {
|
||||
self.set_param("user", user)
|
||||
self.user = Some(user.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the user to authenticate with, if one has been configured with
|
||||
/// the `user` method.
|
||||
pub fn user_is_set(&self) -> bool {
|
||||
self.username
|
||||
pub fn get_user(&self) -> Option<&str> {
|
||||
self.user.as_deref()
|
||||
}
|
||||
|
||||
/// Sets the password to authenticate with.
|
||||
@@ -144,26 +238,40 @@ impl Config {
|
||||
///
|
||||
/// Defaults to the user.
|
||||
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
|
||||
self.set_param("database", dbname)
|
||||
self.dbname = Some(dbname.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the name of the database to connect to, if one has been configured
|
||||
/// with the `dbname` method.
|
||||
pub fn db_is_set(&self) -> bool {
|
||||
self.database
|
||||
pub fn get_dbname(&self) -> Option<&str> {
|
||||
self.dbname.as_deref()
|
||||
}
|
||||
|
||||
pub fn set_param(&mut self, name: &str, value: &str) -> &mut Config {
|
||||
if name == "database" {
|
||||
self.database = true;
|
||||
} else if name == "user" {
|
||||
self.username = true;
|
||||
}
|
||||
|
||||
self.server_params.insert(name, value);
|
||||
/// Sets command line options used to configure the server.
|
||||
pub fn options(&mut self, options: &str) -> &mut Config {
|
||||
self.options = Some(options.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the command line options used to configure the server, if the
|
||||
/// options have been set with the `options` method.
|
||||
pub fn get_options(&self) -> Option<&str> {
|
||||
self.options.as_deref()
|
||||
}
|
||||
|
||||
/// Sets the value of the `application_name` runtime parameter.
|
||||
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
|
||||
self.application_name = Some(application_name.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the value of the `application_name` runtime parameter, if it has
|
||||
/// been set with the `application_name` method.
|
||||
pub fn get_application_name(&self) -> Option<&str> {
|
||||
self.application_name.as_deref()
|
||||
}
|
||||
|
||||
/// Sets the SSL configuration.
|
||||
///
|
||||
/// Defaults to `prefer`.
|
||||
@@ -177,14 +285,32 @@ impl Config {
|
||||
self.ssl_mode
|
||||
}
|
||||
|
||||
/// Adds a host to the configuration.
|
||||
///
|
||||
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order.
|
||||
pub fn host(&mut self, host: &str) -> &mut Config {
|
||||
self.host.push(Host::Tcp(host.to_string()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the hosts that have been added to the configuration with `host`.
|
||||
pub fn get_host(&self) -> &Host {
|
||||
pub fn get_hosts(&self) -> &[Host] {
|
||||
&self.host
|
||||
}
|
||||
|
||||
/// Adds a port to the configuration.
|
||||
///
|
||||
/// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
|
||||
/// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
|
||||
/// as hosts.
|
||||
pub fn port(&mut self, port: u16) -> &mut Config {
|
||||
self.port.push(port);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the ports that have been added to the configuration with `port`.
|
||||
pub fn get_port(&self) -> u16 {
|
||||
self.port
|
||||
pub fn get_ports(&self) -> &[u16] {
|
||||
&self.port
|
||||
}
|
||||
|
||||
/// Sets the timeout applied to socket-level connection attempts.
|
||||
@@ -202,6 +328,23 @@ impl Config {
|
||||
self.connect_timeout.as_ref()
|
||||
}
|
||||
|
||||
/// Sets the requirements of the session.
|
||||
///
|
||||
/// This can be used to connect to the primary server in a clustered database rather than one of the read-only
|
||||
/// secondary servers. Defaults to `Any`.
|
||||
pub fn target_session_attrs(
|
||||
&mut self,
|
||||
target_session_attrs: TargetSessionAttrs,
|
||||
) -> &mut Config {
|
||||
self.target_session_attrs = target_session_attrs;
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the requirements of the session.
|
||||
pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
|
||||
self.target_session_attrs
|
||||
}
|
||||
|
||||
/// Sets the channel binding behavior.
|
||||
///
|
||||
/// Defaults to `prefer`.
|
||||
@@ -215,6 +358,121 @@ impl Config {
|
||||
self.channel_binding
|
||||
}
|
||||
|
||||
/// Set replication mode.
|
||||
pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config {
|
||||
self.replication_mode = Some(replication_mode);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get replication mode.
|
||||
pub fn get_replication_mode(&self) -> Option<ReplicationMode> {
|
||||
self.replication_mode
|
||||
}
|
||||
|
||||
/// Set limit for backend messages size.
|
||||
pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config {
|
||||
self.max_backend_message_size = Some(max_backend_message_size);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get limit for backend messages size.
|
||||
pub fn get_max_backend_message_size(&self) -> Option<usize> {
|
||||
self.max_backend_message_size
|
||||
}
|
||||
|
||||
fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
|
||||
match key {
|
||||
"user" => {
|
||||
self.user(value);
|
||||
}
|
||||
"password" => {
|
||||
self.password(value);
|
||||
}
|
||||
"dbname" => {
|
||||
self.dbname(value);
|
||||
}
|
||||
"options" => {
|
||||
self.options(value);
|
||||
}
|
||||
"application_name" => {
|
||||
self.application_name(value);
|
||||
}
|
||||
"sslmode" => {
|
||||
let mode = match value {
|
||||
"disable" => SslMode::Disable,
|
||||
"prefer" => SslMode::Prefer,
|
||||
"require" => SslMode::Require,
|
||||
_ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
|
||||
};
|
||||
self.ssl_mode(mode);
|
||||
}
|
||||
"host" => {
|
||||
for host in value.split(',') {
|
||||
self.host(host);
|
||||
}
|
||||
}
|
||||
"port" => {
|
||||
for port in value.split(',') {
|
||||
let port = if port.is_empty() {
|
||||
5432
|
||||
} else {
|
||||
port.parse()
|
||||
.map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))?
|
||||
};
|
||||
self.port(port);
|
||||
}
|
||||
}
|
||||
"connect_timeout" => {
|
||||
let timeout = value
|
||||
.parse::<i64>()
|
||||
.map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?;
|
||||
if timeout > 0 {
|
||||
self.connect_timeout(Duration::from_secs(timeout as u64));
|
||||
}
|
||||
}
|
||||
"target_session_attrs" => {
|
||||
let target_session_attrs = match value {
|
||||
"any" => TargetSessionAttrs::Any,
|
||||
"read-write" => TargetSessionAttrs::ReadWrite,
|
||||
_ => {
|
||||
return Err(Error::config_parse(Box::new(InvalidValue(
|
||||
"target_session_attrs",
|
||||
))));
|
||||
}
|
||||
};
|
||||
self.target_session_attrs(target_session_attrs);
|
||||
}
|
||||
"channel_binding" => {
|
||||
let channel_binding = match value {
|
||||
"disable" => ChannelBinding::Disable,
|
||||
"prefer" => ChannelBinding::Prefer,
|
||||
"require" => ChannelBinding::Require,
|
||||
_ => {
|
||||
return Err(Error::config_parse(Box::new(InvalidValue(
|
||||
"channel_binding",
|
||||
))))
|
||||
}
|
||||
};
|
||||
self.channel_binding(channel_binding);
|
||||
}
|
||||
"max_backend_message_size" => {
|
||||
let limit = value.parse::<usize>().map_err(|_| {
|
||||
Error::config_parse(Box::new(InvalidValue("max_backend_message_size")))
|
||||
})?;
|
||||
if limit > 0 {
|
||||
self.max_backend_message_size(limit);
|
||||
}
|
||||
}
|
||||
key => {
|
||||
return Err(Error::config_parse(Box::new(UnknownOption(
|
||||
key.to_string(),
|
||||
))));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Opens a connection to a PostgreSQL database.
|
||||
///
|
||||
/// Requires the `runtime` Cargo feature (enabled by default).
|
||||
@@ -241,6 +499,17 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Config {
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Config, Error> {
|
||||
match UrlParser::parse(s)? {
|
||||
Some(config) => Ok(config),
|
||||
None => Parser::parse(s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Omit password from debug output
|
||||
impl fmt::Debug for Config {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
@@ -252,13 +521,375 @@ impl fmt::Debug for Config {
|
||||
}
|
||||
|
||||
f.debug_struct("Config")
|
||||
.field("user", &self.user)
|
||||
.field("password", &self.password.as_ref().map(|_| Redaction {}))
|
||||
.field("dbname", &self.dbname)
|
||||
.field("options", &self.options)
|
||||
.field("application_name", &self.application_name)
|
||||
.field("ssl_mode", &self.ssl_mode)
|
||||
.field("host", &self.host)
|
||||
.field("port", &self.port)
|
||||
.field("connect_timeout", &self.connect_timeout)
|
||||
.field("target_session_attrs", &self.target_session_attrs)
|
||||
.field("channel_binding", &self.channel_binding)
|
||||
.field("server_params", &self.server_params)
|
||||
.field("replication", &self.replication_mode)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UnknownOption(String);
|
||||
|
||||
impl fmt::Display for UnknownOption {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(fmt, "unknown option `{}`", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl error::Error for UnknownOption {}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InvalidValue(&'static str);
|
||||
|
||||
impl fmt::Display for InvalidValue {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(fmt, "invalid value for option `{}`", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl error::Error for InvalidValue {}
|
||||
|
||||
struct Parser<'a> {
|
||||
s: &'a str,
|
||||
it: iter::Peekable<str::CharIndices<'a>>,
|
||||
}
|
||||
|
||||
impl<'a> Parser<'a> {
|
||||
fn parse(s: &'a str) -> Result<Config, Error> {
|
||||
let mut parser = Parser {
|
||||
s,
|
||||
it: s.char_indices().peekable(),
|
||||
};
|
||||
|
||||
let mut config = Config::new();
|
||||
|
||||
while let Some((key, value)) = parser.parameter()? {
|
||||
config.param(key, &value)?;
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn skip_ws(&mut self) {
|
||||
self.take_while(char::is_whitespace);
|
||||
}
|
||||
|
||||
fn take_while<F>(&mut self, f: F) -> &'a str
|
||||
where
|
||||
F: Fn(char) -> bool,
|
||||
{
|
||||
let start = match self.it.peek() {
|
||||
Some(&(i, _)) => i,
|
||||
None => return "",
|
||||
};
|
||||
|
||||
loop {
|
||||
match self.it.peek() {
|
||||
Some(&(_, c)) if f(c) => {
|
||||
self.it.next();
|
||||
}
|
||||
Some(&(i, _)) => return &self.s[start..i],
|
||||
None => return &self.s[start..],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn eat(&mut self, target: char) -> Result<(), Error> {
|
||||
match self.it.next() {
|
||||
Some((_, c)) if c == target => Ok(()),
|
||||
Some((i, c)) => {
|
||||
let m = format!(
|
||||
"unexpected character at byte {}: expected `{}` but got `{}`",
|
||||
i, target, c
|
||||
);
|
||||
Err(Error::config_parse(m.into()))
|
||||
}
|
||||
None => Err(Error::config_parse("unexpected EOF".into())),
|
||||
}
|
||||
}
|
||||
|
||||
fn eat_if(&mut self, target: char) -> bool {
|
||||
match self.it.peek() {
|
||||
Some(&(_, c)) if c == target => {
|
||||
self.it.next();
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn keyword(&mut self) -> Option<&'a str> {
|
||||
let s = self.take_while(|c| match c {
|
||||
c if c.is_whitespace() => false,
|
||||
'=' => false,
|
||||
_ => true,
|
||||
});
|
||||
|
||||
if s.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(s)
|
||||
}
|
||||
}
|
||||
|
||||
fn value(&mut self) -> Result<String, Error> {
|
||||
let value = if self.eat_if('\'') {
|
||||
let value = self.quoted_value()?;
|
||||
self.eat('\'')?;
|
||||
value
|
||||
} else {
|
||||
self.simple_value()?
|
||||
};
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn simple_value(&mut self) -> Result<String, Error> {
|
||||
let mut value = String::new();
|
||||
|
||||
while let Some(&(_, c)) = self.it.peek() {
|
||||
if c.is_whitespace() {
|
||||
break;
|
||||
}
|
||||
|
||||
self.it.next();
|
||||
if c == '\\' {
|
||||
if let Some((_, c2)) = self.it.next() {
|
||||
value.push(c2);
|
||||
}
|
||||
} else {
|
||||
value.push(c);
|
||||
}
|
||||
}
|
||||
|
||||
if value.is_empty() {
|
||||
return Err(Error::config_parse("unexpected EOF".into()));
|
||||
}
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn quoted_value(&mut self) -> Result<String, Error> {
|
||||
let mut value = String::new();
|
||||
|
||||
while let Some(&(_, c)) = self.it.peek() {
|
||||
if c == '\'' {
|
||||
return Ok(value);
|
||||
}
|
||||
|
||||
self.it.next();
|
||||
if c == '\\' {
|
||||
if let Some((_, c2)) = self.it.next() {
|
||||
value.push(c2);
|
||||
}
|
||||
} else {
|
||||
value.push(c);
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::config_parse(
|
||||
"unterminated quoted connection parameter value".into(),
|
||||
))
|
||||
}
|
||||
|
||||
fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
|
||||
self.skip_ws();
|
||||
let keyword = match self.keyword() {
|
||||
Some(keyword) => keyword,
|
||||
None => return Ok(None),
|
||||
};
|
||||
self.skip_ws();
|
||||
self.eat('=')?;
|
||||
self.skip_ws();
|
||||
let value = self.value()?;
|
||||
|
||||
Ok(Some((keyword, value)))
|
||||
}
|
||||
}
|
||||
|
||||
// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict
|
||||
struct UrlParser<'a> {
|
||||
s: &'a str,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl<'a> UrlParser<'a> {
|
||||
fn parse(s: &'a str) -> Result<Option<Config>, Error> {
|
||||
let s = match Self::remove_url_prefix(s) {
|
||||
Some(s) => s,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let mut parser = UrlParser {
|
||||
s,
|
||||
config: Config::new(),
|
||||
};
|
||||
|
||||
parser.parse_credentials()?;
|
||||
parser.parse_host()?;
|
||||
parser.parse_path()?;
|
||||
parser.parse_params()?;
|
||||
|
||||
Ok(Some(parser.config))
|
||||
}
|
||||
|
||||
fn remove_url_prefix(s: &str) -> Option<&str> {
|
||||
for prefix in &["postgres://", "postgresql://"] {
|
||||
if let Some(stripped) = s.strip_prefix(prefix) {
|
||||
return Some(stripped);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
|
||||
match self.s.find(end) {
|
||||
Some(pos) => {
|
||||
let (head, tail) = self.s.split_at(pos);
|
||||
self.s = tail;
|
||||
Some(head)
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn take_all(&mut self) -> &'a str {
|
||||
mem::take(&mut self.s)
|
||||
}
|
||||
|
||||
fn eat_byte(&mut self) {
|
||||
self.s = &self.s[1..];
|
||||
}
|
||||
|
||||
fn parse_credentials(&mut self) -> Result<(), Error> {
|
||||
let creds = match self.take_until(&['@']) {
|
||||
Some(creds) => creds,
|
||||
None => return Ok(()),
|
||||
};
|
||||
self.eat_byte();
|
||||
|
||||
let mut it = creds.splitn(2, ':');
|
||||
let user = self.decode(it.next().unwrap())?;
|
||||
self.config.user(&user);
|
||||
|
||||
if let Some(password) = it.next() {
|
||||
let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
|
||||
self.config.password(password);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_host(&mut self) -> Result<(), Error> {
|
||||
let host = match self.take_until(&['/', '?']) {
|
||||
Some(host) => host,
|
||||
None => self.take_all(),
|
||||
};
|
||||
|
||||
if host.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for chunk in host.split(',') {
|
||||
let (host, port) = if chunk.starts_with('[') {
|
||||
let idx = match chunk.find(']') {
|
||||
Some(idx) => idx,
|
||||
None => return Err(Error::config_parse(InvalidValue("host").into())),
|
||||
};
|
||||
|
||||
let host = &chunk[1..idx];
|
||||
let remaining = &chunk[idx + 1..];
|
||||
let port = if let Some(port) = remaining.strip_prefix(':') {
|
||||
Some(port)
|
||||
} else if remaining.is_empty() {
|
||||
None
|
||||
} else {
|
||||
return Err(Error::config_parse(InvalidValue("host").into()));
|
||||
};
|
||||
|
||||
(host, port)
|
||||
} else {
|
||||
let mut it = chunk.splitn(2, ':');
|
||||
(it.next().unwrap(), it.next())
|
||||
};
|
||||
|
||||
self.host_param(host)?;
|
||||
let port = self.decode(port.unwrap_or("5432"))?;
|
||||
self.config.param("port", &port)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_path(&mut self) -> Result<(), Error> {
|
||||
if !self.s.starts_with('/') {
|
||||
return Ok(());
|
||||
}
|
||||
self.eat_byte();
|
||||
|
||||
let dbname = match self.take_until(&['?']) {
|
||||
Some(dbname) => dbname,
|
||||
None => self.take_all(),
|
||||
};
|
||||
|
||||
if !dbname.is_empty() {
|
||||
self.config.dbname(&self.decode(dbname)?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_params(&mut self) -> Result<(), Error> {
|
||||
if !self.s.starts_with('?') {
|
||||
return Ok(());
|
||||
}
|
||||
self.eat_byte();
|
||||
|
||||
while !self.s.is_empty() {
|
||||
let key = match self.take_until(&['=']) {
|
||||
Some(key) => self.decode(key)?,
|
||||
None => return Err(Error::config_parse("unterminated parameter".into())),
|
||||
};
|
||||
self.eat_byte();
|
||||
|
||||
let value = match self.take_until(&['&']) {
|
||||
Some(value) => {
|
||||
self.eat_byte();
|
||||
value
|
||||
}
|
||||
None => self.take_all(),
|
||||
};
|
||||
|
||||
if key == "host" {
|
||||
self.host_param(value)?;
|
||||
} else {
|
||||
let value = self.decode(value)?;
|
||||
self.config.param(&key, &value)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn host_param(&mut self, s: &str) -> Result<(), Error> {
|
||||
let s = self.decode(s)?;
|
||||
self.config.param("host", &s)
|
||||
}
|
||||
|
||||
fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
|
||||
percent_encoding::percent_decode(s.as_bytes())
|
||||
.decode_utf8()
|
||||
.map_err(|e| Error::config_parse(e.into()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
use crate::client::SocketConfig;
|
||||
use crate::codec::BackendMessage;
|
||||
use crate::config::Host;
|
||||
use crate::config::{Host, TargetSessionAttrs};
|
||||
use crate::connect_raw::connect_raw;
|
||||
use crate::connect_socket::connect_socket;
|
||||
use crate::tls::{MakeTlsConnect, TlsConnect};
|
||||
use crate::{Client, Config, Connection, Error, RawConnection};
|
||||
use crate::{Client, Config, Connection, Error, RawConnection, SimpleQueryMessage};
|
||||
use futures_util::{future, pin_mut, Future, FutureExt, Stream};
|
||||
use postgres_protocol2::message::backend::Message;
|
||||
use std::io;
|
||||
use std::task::Poll;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
@@ -16,18 +19,38 @@ pub async fn connect<T>(
|
||||
where
|
||||
T: MakeTlsConnect<TcpStream>,
|
||||
{
|
||||
let hostname = match &config.host {
|
||||
Host::Tcp(host) => host.as_str(),
|
||||
};
|
||||
|
||||
let tls = tls
|
||||
.make_tls_connect(hostname)
|
||||
.map_err(|e| Error::tls(e.into()))?;
|
||||
|
||||
match connect_once(&config.host, config.port, tls, config).await {
|
||||
Ok((client, connection)) => Ok((client, connection)),
|
||||
Err(e) => Err(e),
|
||||
if config.host.is_empty() {
|
||||
return Err(Error::config("host missing".into()));
|
||||
}
|
||||
|
||||
if config.port.len() > 1 && config.port.len() != config.host.len() {
|
||||
return Err(Error::config("invalid number of ports".into()));
|
||||
}
|
||||
|
||||
let mut error = None;
|
||||
for (i, host) in config.host.iter().enumerate() {
|
||||
let port = config
|
||||
.port
|
||||
.get(i)
|
||||
.or_else(|| config.port.first())
|
||||
.copied()
|
||||
.unwrap_or(5432);
|
||||
|
||||
let hostname = match host {
|
||||
Host::Tcp(host) => host.as_str(),
|
||||
};
|
||||
|
||||
let tls = tls
|
||||
.make_tls_connect(hostname)
|
||||
.map_err(|e| Error::tls(e.into()))?;
|
||||
|
||||
match connect_once(host, port, tls, config).await {
|
||||
Ok((client, connection)) => return Ok((client, connection)),
|
||||
Err(e) => error = Some(e),
|
||||
}
|
||||
}
|
||||
|
||||
Err(error.unwrap())
|
||||
}
|
||||
|
||||
async fn connect_once<T>(
|
||||
@@ -69,7 +92,47 @@ where
|
||||
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
|
||||
.collect();
|
||||
|
||||
let connection = Connection::new(stream, delayed, parameters, receiver);
|
||||
let mut connection = Connection::new(stream, delayed, parameters, receiver);
|
||||
|
||||
if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
|
||||
let rows = client.simple_query_raw("SHOW transaction_read_only");
|
||||
pin_mut!(rows);
|
||||
|
||||
let rows = future::poll_fn(|cx| {
|
||||
if connection.poll_unpin(cx)?.is_ready() {
|
||||
return Poll::Ready(Err(Error::closed()));
|
||||
}
|
||||
|
||||
rows.as_mut().poll(cx)
|
||||
})
|
||||
.await?;
|
||||
pin_mut!(rows);
|
||||
|
||||
loop {
|
||||
let next = future::poll_fn(|cx| {
|
||||
if connection.poll_unpin(cx)?.is_ready() {
|
||||
return Poll::Ready(Some(Err(Error::closed())));
|
||||
}
|
||||
|
||||
rows.as_mut().poll_next(cx)
|
||||
});
|
||||
|
||||
match next.await.transpose()? {
|
||||
Some(SimpleQueryMessage::Row(row)) => {
|
||||
if row.try_get(0)? == Some("on") {
|
||||
return Err(Error::connect(io::Error::new(
|
||||
io::ErrorKind::PermissionDenied,
|
||||
"database does not allow writes",
|
||||
)));
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(_) => {}
|
||||
None => return Err(Error::unexpected_message()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok((client, connection))
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
|
||||
use crate::config::{self, AuthKeys, Config};
|
||||
use crate::config::{self, AuthKeys, Config, ReplicationMode};
|
||||
use crate::connect_tls::connect_tls;
|
||||
use crate::maybe_tls_stream::MaybeTlsStream;
|
||||
use crate::tls::{TlsConnect, TlsStream};
|
||||
@@ -7,6 +7,7 @@ use crate::Error;
|
||||
use bytes::BytesMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
|
||||
use postgres_protocol2::authentication;
|
||||
use postgres_protocol2::authentication::sasl;
|
||||
use postgres_protocol2::authentication::sasl::ScramSha256;
|
||||
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
|
||||
@@ -96,7 +97,12 @@ where
|
||||
let stream = connect_tls(stream, config.ssl_mode, tls).await?;
|
||||
|
||||
let mut stream = StartupStream {
|
||||
inner: Framed::new(stream, PostgresCodec),
|
||||
inner: Framed::new(
|
||||
stream,
|
||||
PostgresCodec {
|
||||
max_message_size: config.max_backend_message_size,
|
||||
},
|
||||
),
|
||||
buf: BackendMessages::empty(),
|
||||
delayed_notice: Vec::new(),
|
||||
};
|
||||
@@ -119,8 +125,28 @@ where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut params = vec![("client_encoding", "UTF8")];
|
||||
if let Some(user) = &config.user {
|
||||
params.push(("user", &**user));
|
||||
}
|
||||
if let Some(dbname) = &config.dbname {
|
||||
params.push(("database", &**dbname));
|
||||
}
|
||||
if let Some(options) = &config.options {
|
||||
params.push(("options", &**options));
|
||||
}
|
||||
if let Some(application_name) = &config.application_name {
|
||||
params.push(("application_name", &**application_name));
|
||||
}
|
||||
if let Some(replication_mode) = &config.replication_mode {
|
||||
match replication_mode {
|
||||
ReplicationMode::Physical => params.push(("replication", "true")),
|
||||
ReplicationMode::Logical => params.push(("replication", "database")),
|
||||
}
|
||||
}
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
|
||||
frontend::startup_message(params, &mut buf).map_err(Error::encode)?;
|
||||
|
||||
stream
|
||||
.send(FrontendMessage::Raw(buf.freeze()))
|
||||
@@ -148,11 +174,25 @@ where
|
||||
|
||||
authenticate_password(stream, pass).await?;
|
||||
}
|
||||
Some(Message::AuthenticationMd5Password(body)) => {
|
||||
can_skip_channel_binding(config)?;
|
||||
|
||||
let user = config
|
||||
.user
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::config("user missing".into()))?;
|
||||
let pass = config
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::config("password missing".into()))?;
|
||||
|
||||
let output = authentication::md5_hash(user.as_bytes(), pass, body.salt());
|
||||
authenticate_password(stream, output.as_bytes()).await?;
|
||||
}
|
||||
Some(Message::AuthenticationSasl(body)) => {
|
||||
authenticate_sasl(stream, body, config).await?;
|
||||
}
|
||||
Some(Message::AuthenticationMd5Password)
|
||||
| Some(Message::AuthenticationKerberosV5)
|
||||
Some(Message::AuthenticationKerberosV5)
|
||||
| Some(Message::AuthenticationScmCredential)
|
||||
| Some(Message::AuthenticationGss)
|
||||
| Some(Message::AuthenticationSspi) => {
|
||||
|
||||
@@ -349,6 +349,7 @@ enum Kind {
|
||||
Parse,
|
||||
Encode,
|
||||
Authentication,
|
||||
ConfigParse,
|
||||
Config,
|
||||
Connect,
|
||||
Timeout,
|
||||
@@ -385,6 +386,7 @@ impl fmt::Display for Error {
|
||||
Kind::Parse => fmt.write_str("error parsing response from server")?,
|
||||
Kind::Encode => fmt.write_str("error encoding message to server")?,
|
||||
Kind::Authentication => fmt.write_str("authentication error")?,
|
||||
Kind::ConfigParse => fmt.write_str("invalid connection string")?,
|
||||
Kind::Config => fmt.write_str("invalid configuration")?,
|
||||
Kind::Connect => fmt.write_str("error connecting to server")?,
|
||||
Kind::Timeout => fmt.write_str("timeout waiting for server")?,
|
||||
@@ -480,6 +482,10 @@ impl Error {
|
||||
Error::new(Kind::Authentication, Some(e))
|
||||
}
|
||||
|
||||
pub(crate) fn config_parse(e: Box<dyn error::Error + Sync + Send>) -> Error {
|
||||
Error::new(Kind::ConfigParse, Some(e))
|
||||
}
|
||||
|
||||
pub(crate) fn config(e: Box<dyn error::Error + Sync + Send>) -> Error {
|
||||
Error::new(Kind::Config, Some(e))
|
||||
}
|
||||
|
||||
@@ -13,12 +13,14 @@ pub use crate::query::RowStream;
|
||||
pub use crate::row::{Row, SimpleQueryRow};
|
||||
pub use crate::simple_query::SimpleQueryStream;
|
||||
pub use crate::statement::{Column, Statement};
|
||||
use crate::tls::MakeTlsConnect;
|
||||
pub use crate::tls::NoTls;
|
||||
pub use crate::to_statement::ToStatement;
|
||||
pub use crate::transaction::Transaction;
|
||||
pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
|
||||
use crate::types::ToSql;
|
||||
use postgres_protocol2::message::backend::ReadyForQueryBody;
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// After executing a query, the connection will be in one of these states
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
@@ -70,6 +72,24 @@ mod transaction;
|
||||
mod transaction_builder;
|
||||
pub mod types;
|
||||
|
||||
/// A convenience function which parses a connection string and connects to the database.
|
||||
///
|
||||
/// See the documentation for [`Config`] for details on the connection string format.
|
||||
///
|
||||
/// Requires the `runtime` Cargo feature (enabled by default).
|
||||
///
|
||||
/// [`Config`]: config/struct.Config.html
|
||||
pub async fn connect<T>(
|
||||
config: &str,
|
||||
tls: T,
|
||||
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
|
||||
where
|
||||
T: MakeTlsConnect<TcpStream>,
|
||||
{
|
||||
let config = config.parse::<Config>()?;
|
||||
config.connect(tls).await
|
||||
}
|
||||
|
||||
/// An asynchronous notification.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Notification {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
pub mod heavier_once_cell;
|
||||
|
||||
pub mod duplex;
|
||||
pub mod gate;
|
||||
|
||||
pub mod spsc_fold;
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
pub mod mpsc;
|
||||
@@ -1,36 +0,0 @@
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// A bi-directional channel.
|
||||
pub struct Duplex<S, R> {
|
||||
pub tx: mpsc::Sender<S>,
|
||||
pub rx: mpsc::Receiver<R>,
|
||||
}
|
||||
|
||||
/// Creates a bi-directional channel.
|
||||
///
|
||||
/// The channel will buffer up to the provided number of messages. Once the buffer is full,
|
||||
/// attempts to send new messages will wait until a message is received from the channel.
|
||||
/// The provided buffer capacity must be at least 1.
|
||||
pub fn channel<A: Send, B: Send>(buffer: usize) -> (Duplex<A, B>, Duplex<B, A>) {
|
||||
let (tx_a, rx_a) = mpsc::channel::<A>(buffer);
|
||||
let (tx_b, rx_b) = mpsc::channel::<B>(buffer);
|
||||
|
||||
(Duplex { tx: tx_a, rx: rx_b }, Duplex { tx: tx_b, rx: rx_a })
|
||||
}
|
||||
|
||||
impl<S: Send, R: Send> Duplex<S, R> {
|
||||
/// Sends a value, waiting until there is capacity.
|
||||
///
|
||||
/// A successful send occurs when it is determined that the other end of the channel has not hung up already.
|
||||
pub async fn send(&self, x: S) -> Result<(), mpsc::error::SendError<S>> {
|
||||
self.tx.send(x).await
|
||||
}
|
||||
|
||||
/// Receives the next value for this receiver.
|
||||
///
|
||||
/// This method returns `None` if the channel has been closed and there are
|
||||
/// no remaining messages in the channel's buffer.
|
||||
pub async fn recv(&mut self) -> Option<R> {
|
||||
self.rx.recv().await
|
||||
}
|
||||
}
|
||||
@@ -112,38 +112,30 @@ impl MetadataRecord {
|
||||
};
|
||||
|
||||
// Next, filter the metadata record by shard.
|
||||
match metadata_record {
|
||||
Some(
|
||||
MetadataRecord::Heapam(HeapamRecord::ClearVmBits(ref mut clear_vm_bits))
|
||||
| MetadataRecord::Neonrmgr(NeonrmgrRecord::ClearVmBits(ref mut clear_vm_bits)),
|
||||
) => {
|
||||
// Route VM page updates to the shards that own them. VM pages are stored in the VM fork
|
||||
// of the main relation. These are sharded and managed just like regular relation pages.
|
||||
// See: https://github.com/neondatabase/neon/issues/9855
|
||||
let is_local_vm_page = |heap_blk| {
|
||||
let vm_blk = pg_constants::HEAPBLK_TO_MAPBLOCK(heap_blk);
|
||||
shard.is_key_local(&rel_block_to_key(clear_vm_bits.vm_rel, vm_blk))
|
||||
};
|
||||
// Send the old and new VM page updates to their respective shards.
|
||||
clear_vm_bits.old_heap_blkno = clear_vm_bits
|
||||
.old_heap_blkno
|
||||
.filter(|&blkno| is_local_vm_page(blkno));
|
||||
clear_vm_bits.new_heap_blkno = clear_vm_bits
|
||||
.new_heap_blkno
|
||||
.filter(|&blkno| is_local_vm_page(blkno));
|
||||
// If neither VM page belongs to this shard, discard the record.
|
||||
if clear_vm_bits.old_heap_blkno.is_none() && clear_vm_bits.new_heap_blkno.is_none()
|
||||
{
|
||||
metadata_record = None
|
||||
}
|
||||
|
||||
// Route VM page updates to the shards that own them. VM pages are stored in the VM fork
|
||||
// of the main relation. These are sharded and managed just like regular relation pages.
|
||||
// See: https://github.com/neondatabase/neon/issues/9855
|
||||
if let Some(
|
||||
MetadataRecord::Heapam(HeapamRecord::ClearVmBits(ref mut clear_vm_bits))
|
||||
| MetadataRecord::Neonrmgr(NeonrmgrRecord::ClearVmBits(ref mut clear_vm_bits)),
|
||||
) = metadata_record
|
||||
{
|
||||
let is_local_vm_page = |heap_blk| {
|
||||
let vm_blk = pg_constants::HEAPBLK_TO_MAPBLOCK(heap_blk);
|
||||
shard.is_key_local(&rel_block_to_key(clear_vm_bits.vm_rel, vm_blk))
|
||||
};
|
||||
// Send the old and new VM page updates to their respective shards.
|
||||
clear_vm_bits.old_heap_blkno = clear_vm_bits
|
||||
.old_heap_blkno
|
||||
.filter(|&blkno| is_local_vm_page(blkno));
|
||||
clear_vm_bits.new_heap_blkno = clear_vm_bits
|
||||
.new_heap_blkno
|
||||
.filter(|&blkno| is_local_vm_page(blkno));
|
||||
// If neither VM page belongs to this shard, discard the record.
|
||||
if clear_vm_bits.old_heap_blkno.is_none() && clear_vm_bits.new_heap_blkno.is_none() {
|
||||
metadata_record = None
|
||||
}
|
||||
Some(MetadataRecord::LogicalMessage(LogicalMessageRecord::Put(_))) => {
|
||||
// Filter LogicalMessage records (AUX files) to only be stored on shard zero
|
||||
if !shard.is_shard_zero() {
|
||||
metadata_record = None;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(metadata_record)
|
||||
|
||||
@@ -62,8 +62,10 @@ async fn ingest(
|
||||
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);
|
||||
|
||||
let gate = utils::sync::gate::Gate::default();
|
||||
let entered = gate.enter().unwrap();
|
||||
|
||||
let layer = InMemoryLayer::create(conf, timeline_id, tenant_shard_id, lsn, &gate, &ctx).await?;
|
||||
let layer =
|
||||
InMemoryLayer::create(conf, timeline_id, tenant_shard_id, lsn, entered, &ctx).await?;
|
||||
|
||||
let data = Value::Image(Bytes::from(vec![0u8; put_size]));
|
||||
let data_ser_size = data.serialized_size().unwrap() as usize;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{collections::HashMap, error::Error as _};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use bytes::Bytes;
|
||||
use detach_ancestor::AncestorDetached;
|
||||
@@ -25,10 +25,10 @@ pub struct Client {
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("send request: {0}{}", .0.source().map(|e| format!(": {e}")).unwrap_or_default())]
|
||||
#[error("send request: {0}")]
|
||||
SendRequest(reqwest::Error),
|
||||
|
||||
#[error("receive body: {0}{}", .0.source().map(|e| format!(": {e}")).unwrap_or_default())]
|
||||
#[error("receive body: {0}")]
|
||||
ReceiveBody(reqwest::Error),
|
||||
|
||||
#[error("receive error body: {0}")]
|
||||
|
||||
@@ -636,59 +636,45 @@ fn start_pageserver(
|
||||
tokio::net::TcpListener::from_std(pageserver_listener).context("create tokio listener")?
|
||||
});
|
||||
|
||||
// All started up! Now just sit and wait for shutdown signal.
|
||||
BACKGROUND_RUNTIME.block_on(async move {
|
||||
let signal_token = CancellationToken::new();
|
||||
let signal_cancel = signal_token.child_token();
|
||||
let mut shutdown_pageserver = Some(shutdown_pageserver.drop_guard());
|
||||
|
||||
// Spawn signal handlers. Runs in a loop since we want to be responsive to multiple signals
|
||||
// even after triggering shutdown (e.g. a SIGQUIT after a slow SIGTERM shutdown). See:
|
||||
// https://github.com/neondatabase/neon/issues/9740.
|
||||
tokio::spawn(async move {
|
||||
// All started up! Now just sit and wait for shutdown signal.
|
||||
|
||||
{
|
||||
BACKGROUND_RUNTIME.block_on(async move {
|
||||
let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt()).unwrap();
|
||||
let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate()).unwrap();
|
||||
let mut sigquit = tokio::signal::unix::signal(SignalKind::quit()).unwrap();
|
||||
|
||||
loop {
|
||||
let signal = tokio::select! {
|
||||
_ = sigquit.recv() => {
|
||||
info!("Got signal SIGQUIT. Terminating in immediate shutdown mode.");
|
||||
std::process::exit(111);
|
||||
}
|
||||
_ = sigint.recv() => "SIGINT",
|
||||
_ = sigterm.recv() => "SIGTERM",
|
||||
};
|
||||
|
||||
if !signal_token.is_cancelled() {
|
||||
info!("Got signal {signal}. Terminating gracefully in fast shutdown mode.");
|
||||
signal_token.cancel();
|
||||
} else {
|
||||
info!("Got signal {signal}. Already shutting down.");
|
||||
let signal = tokio::select! {
|
||||
_ = sigquit.recv() => {
|
||||
info!("Got signal SIGQUIT. Terminating in immediate shutdown mode",);
|
||||
std::process::exit(111);
|
||||
}
|
||||
}
|
||||
});
|
||||
_ = sigint.recv() => { "SIGINT" },
|
||||
_ = sigterm.recv() => { "SIGTERM" },
|
||||
};
|
||||
|
||||
// Wait for cancellation signal and shut down the pageserver.
|
||||
//
|
||||
// This cancels the `shutdown_pageserver` cancellation tree. Right now that tree doesn't
|
||||
// reach very far, and `task_mgr` is used instead. The plan is to change that over time.
|
||||
signal_cancel.cancelled().await;
|
||||
info!("Got signal {signal}. Terminating gracefully in fast shutdown mode",);
|
||||
|
||||
shutdown_pageserver.cancel();
|
||||
pageserver::shutdown_pageserver(
|
||||
http_endpoint_listener,
|
||||
page_service,
|
||||
consumption_metrics_tasks,
|
||||
disk_usage_eviction_task,
|
||||
&tenant_manager,
|
||||
background_purges,
|
||||
deletion_queue.clone(),
|
||||
secondary_controller_tasks,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
unreachable!();
|
||||
})
|
||||
// This cancels the `shutdown_pageserver` cancellation tree.
|
||||
// Right now that tree doesn't reach very far, and `task_mgr` is used instead.
|
||||
// The plan is to change that over time.
|
||||
shutdown_pageserver.take();
|
||||
pageserver::shutdown_pageserver(
|
||||
http_endpoint_listener,
|
||||
page_service,
|
||||
consumption_metrics_tasks,
|
||||
disk_usage_eviction_task,
|
||||
&tenant_manager,
|
||||
background_purges,
|
||||
deletion_queue.clone(),
|
||||
secondary_controller_tasks,
|
||||
0,
|
||||
)
|
||||
.await;
|
||||
unreachable!()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_remote_storage_client(
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::error::Error as _;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
@@ -351,11 +350,7 @@ impl std::fmt::Display for UploadError {
|
||||
|
||||
match self {
|
||||
Rejected(code) => write!(f, "server rejected the metrics with {code}"),
|
||||
Reqwest(e) => write!(
|
||||
f,
|
||||
"request failed: {e}{}",
|
||||
e.source().map(|e| format!(": {e}")).unwrap_or_default()
|
||||
),
|
||||
Reqwest(e) => write!(f, "request failed: {e}"),
|
||||
Cancelled => write!(f, "cancelled"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,10 +115,6 @@ impl ControllerUpcallClient {
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub(crate) fn base_url(&self) -> &Url {
|
||||
&self.base_url
|
||||
}
|
||||
}
|
||||
|
||||
impl ControlPlaneGenerationsApi for ControllerUpcallClient {
|
||||
@@ -195,15 +191,13 @@ impl ControlPlaneGenerationsApi for ControllerUpcallClient {
|
||||
|
||||
let request = ReAttachRequest {
|
||||
node_id: self.node_id,
|
||||
register: register.clone(),
|
||||
register,
|
||||
};
|
||||
|
||||
let response: ReAttachResponse = self.retry_http_forever(&re_attach_path, request).await?;
|
||||
tracing::info!(
|
||||
"Received re-attach response with {} tenants (node {}, register: {:?})",
|
||||
response.tenants.len(),
|
||||
self.node_id,
|
||||
register,
|
||||
"Received re-attach response with {} tenants",
|
||||
response.tenants.len()
|
||||
);
|
||||
|
||||
failpoint_support::sleep_millis_async!("control-plane-client-re-attach");
|
||||
|
||||
@@ -279,10 +279,7 @@ impl From<TenantStateError> for ApiError {
|
||||
impl From<GetTenantError> for ApiError {
|
||||
fn from(tse: GetTenantError) -> ApiError {
|
||||
match tse {
|
||||
GetTenantError::NotFound(tid) => ApiError::NotFound(anyhow!("tenant {tid}").into()),
|
||||
GetTenantError::ShardNotFound(tid) => {
|
||||
ApiError::NotFound(anyhow!("tenant {tid}").into())
|
||||
}
|
||||
GetTenantError::NotFound(tid) => ApiError::NotFound(anyhow!("tenant {}", tid).into()),
|
||||
GetTenantError::NotActive(_) => {
|
||||
// Why is this not `ApiError::NotFound`?
|
||||
// Because we must be careful to never return 404 for a tenant if it does
|
||||
@@ -390,16 +387,6 @@ impl From<crate::tenant::mgr::DeleteTenantError> for ApiError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::tenant::secondary::SecondaryTenantError> for ApiError {
|
||||
fn from(ste: crate::tenant::secondary::SecondaryTenantError) -> ApiError {
|
||||
use crate::tenant::secondary::SecondaryTenantError;
|
||||
match ste {
|
||||
SecondaryTenantError::GetTenant(gte) => gte.into(),
|
||||
SecondaryTenantError::ShuttingDown => ApiError::ShuttingDown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to construct a TimelineInfo struct for a timeline
|
||||
async fn build_timeline_info(
|
||||
timeline: &Arc<Timeline>,
|
||||
@@ -1060,11 +1047,9 @@ async fn timeline_delete_handler(
|
||||
match e {
|
||||
// GetTenantError has a built-in conversion to ApiError, but in this context we don't
|
||||
// want to treat missing tenants as 404, to avoid ambiguity with successful deletions.
|
||||
GetTenantError::NotFound(_) | GetTenantError::ShardNotFound(_) => {
|
||||
ApiError::PreconditionFailed(
|
||||
"Requested tenant is missing".to_string().into_boxed_str(),
|
||||
)
|
||||
}
|
||||
GetTenantError::NotFound(_) => ApiError::PreconditionFailed(
|
||||
"Requested tenant is missing".to_string().into_boxed_str(),
|
||||
),
|
||||
e => e.into(),
|
||||
}
|
||||
})?;
|
||||
@@ -2477,7 +2462,8 @@ async fn secondary_upload_handler(
|
||||
state
|
||||
.secondary_controller
|
||||
.upload_tenant(tenant_shard_id)
|
||||
.await?;
|
||||
.await
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
@@ -2592,7 +2578,7 @@ async fn secondary_download_handler(
|
||||
// Edge case: downloads aren't usually fallible: things like a missing heatmap are considered
|
||||
// okay. We could get an error here in the unlikely edge case that the tenant
|
||||
// was detached between our check above and executing the download job.
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Ok(Err(e)) => return Err(ApiError::InternalServerError(e)),
|
||||
// A timeout is not an error: we have started the download, we're just not done
|
||||
// yet. The caller will get a response body indicating status.
|
||||
Err(_) => StatusCode::ACCEPTED,
|
||||
|
||||
@@ -575,24 +575,18 @@ async fn import_file(
|
||||
} else if file_path.starts_with("pg_xact") {
|
||||
let slru = SlruKind::Clog;
|
||||
|
||||
if modification.tline.tenant_shard_id.is_shard_zero() {
|
||||
import_slru(modification, slru, file_path, reader, len, ctx).await?;
|
||||
debug!("imported clog slru");
|
||||
}
|
||||
import_slru(modification, slru, file_path, reader, len, ctx).await?;
|
||||
debug!("imported clog slru");
|
||||
} else if file_path.starts_with("pg_multixact/offsets") {
|
||||
let slru = SlruKind::MultiXactOffsets;
|
||||
|
||||
if modification.tline.tenant_shard_id.is_shard_zero() {
|
||||
import_slru(modification, slru, file_path, reader, len, ctx).await?;
|
||||
debug!("imported multixact offsets slru");
|
||||
}
|
||||
import_slru(modification, slru, file_path, reader, len, ctx).await?;
|
||||
debug!("imported multixact offsets slru");
|
||||
} else if file_path.starts_with("pg_multixact/members") {
|
||||
let slru = SlruKind::MultiXactMembers;
|
||||
|
||||
if modification.tline.tenant_shard_id.is_shard_zero() {
|
||||
import_slru(modification, slru, file_path, reader, len, ctx).await?;
|
||||
debug!("imported multixact members slru");
|
||||
}
|
||||
import_slru(modification, slru, file_path, reader, len, ctx).await?;
|
||||
debug!("imported multixact members slru");
|
||||
} else if file_path.starts_with("pg_twophase") {
|
||||
let bytes = read_all_bytes(reader).await?;
|
||||
|
||||
|
||||
@@ -217,16 +217,31 @@ impl<'a> ScanLatencyOngoingRecording<'a> {
|
||||
ScanLatencyOngoingRecording { parent, start }
|
||||
}
|
||||
|
||||
pub(crate) fn observe(self) {
|
||||
pub(crate) fn observe(self, throttled: Option<Duration>) {
|
||||
let elapsed = self.start.elapsed();
|
||||
self.parent.observe(elapsed.as_secs_f64());
|
||||
let ex_throttled = if let Some(throttled) = throttled {
|
||||
elapsed.checked_sub(throttled)
|
||||
} else {
|
||||
Some(elapsed)
|
||||
};
|
||||
if let Some(ex_throttled) = ex_throttled {
|
||||
self.parent.observe(ex_throttled.as_secs_f64());
|
||||
} else {
|
||||
use utils::rate_limit::RateLimit;
|
||||
static LOGGED: Lazy<Mutex<RateLimit>> =
|
||||
Lazy::new(|| Mutex::new(RateLimit::new(Duration::from_secs(10))));
|
||||
let mut rate_limit = LOGGED.lock().unwrap();
|
||||
rate_limit.call(|| {
|
||||
warn!("error deducting time spent throttled; this message is logged at a global rate limit");
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) static GET_VECTORED_LATENCY: Lazy<GetVectoredLatency> = Lazy::new(|| {
|
||||
let inner = register_histogram_vec!(
|
||||
"pageserver_get_vectored_seconds",
|
||||
"Time spent in get_vectored.",
|
||||
"Time spent in get_vectored, excluding time spent in timeline_get_throttle.",
|
||||
&["task_kind"],
|
||||
CRITICAL_OP_BUCKETS.into(),
|
||||
)
|
||||
@@ -249,7 +264,7 @@ pub(crate) static GET_VECTORED_LATENCY: Lazy<GetVectoredLatency> = Lazy::new(||
|
||||
pub(crate) static SCAN_LATENCY: Lazy<ScanLatency> = Lazy::new(|| {
|
||||
let inner = register_histogram_vec!(
|
||||
"pageserver_scan_seconds",
|
||||
"Time spent in scan.",
|
||||
"Time spent in scan, excluding time spent in timeline_get_throttle.",
|
||||
&["task_kind"],
|
||||
CRITICAL_OP_BUCKETS.into(),
|
||||
)
|
||||
@@ -1212,44 +1227,11 @@ pub(crate) struct SmgrOpTimer {
|
||||
per_timeline_latency_histo: Option<Histogram>,
|
||||
|
||||
start: Instant,
|
||||
throttled: Duration,
|
||||
op: SmgrQueryType,
|
||||
}
|
||||
|
||||
impl SmgrOpTimer {
|
||||
pub(crate) fn deduct_throttle(&mut self, throttle: &Option<Duration>) {
|
||||
let Some(throttle) = throttle else {
|
||||
return;
|
||||
};
|
||||
self.throttled += *throttle;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SmgrOpTimer {
|
||||
fn drop(&mut self) {
|
||||
let elapsed = self.start.elapsed();
|
||||
|
||||
let elapsed = match elapsed.checked_sub(self.throttled) {
|
||||
Some(elapsed) => elapsed,
|
||||
None => {
|
||||
use utils::rate_limit::RateLimit;
|
||||
static LOGGED: Lazy<Mutex<enum_map::EnumMap<SmgrQueryType, RateLimit>>> =
|
||||
Lazy::new(|| {
|
||||
Mutex::new(enum_map::EnumMap::from_array(std::array::from_fn(|_| {
|
||||
RateLimit::new(Duration::from_secs(10))
|
||||
})))
|
||||
});
|
||||
let mut guard = LOGGED.lock().unwrap();
|
||||
let rate_limit = &mut guard[self.op];
|
||||
rate_limit.call(|| {
|
||||
warn!(op=?self.op, ?elapsed, ?self.throttled, "implementation error: time spent throttled exceeds total request wall clock time");
|
||||
});
|
||||
elapsed // un-throttled time, more info than just saturating to 0
|
||||
}
|
||||
};
|
||||
|
||||
let elapsed = elapsed.as_secs_f64();
|
||||
|
||||
let elapsed = self.start.elapsed().as_secs_f64();
|
||||
self.global_latency_histo.observe(elapsed);
|
||||
if let Some(per_timeline_getpage_histo) = &self.per_timeline_latency_histo {
|
||||
per_timeline_getpage_histo.observe(elapsed);
|
||||
@@ -1509,8 +1491,6 @@ impl SmgrQueryTimePerTimeline {
|
||||
global_latency_histo: self.global_latency[op as usize].clone(),
|
||||
per_timeline_latency_histo,
|
||||
start: started_at,
|
||||
op,
|
||||
throttled: Duration::ZERO,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3319,7 +3299,7 @@ pub(crate) mod tenant_throttling {
|
||||
use once_cell::sync::Lazy;
|
||||
use utils::shard::TenantShardId;
|
||||
|
||||
use crate::tenant::{self};
|
||||
use crate::tenant::{self, throttle::Metric};
|
||||
|
||||
struct GlobalAndPerTenantIntCounter {
|
||||
global: IntCounter,
|
||||
@@ -3338,7 +3318,7 @@ pub(crate) mod tenant_throttling {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct Metrics<const KIND: usize> {
|
||||
pub(crate) struct TimelineGet {
|
||||
count_accounted_start: GlobalAndPerTenantIntCounter,
|
||||
count_accounted_finish: GlobalAndPerTenantIntCounter,
|
||||
wait_time: GlobalAndPerTenantIntCounter,
|
||||
@@ -3411,41 +3391,40 @@ pub(crate) mod tenant_throttling {
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
const KINDS: &[&str] = &["pagestream"];
|
||||
pub type Pagestream = Metrics<0>;
|
||||
const KIND: &str = "timeline_get";
|
||||
|
||||
impl<const KIND: usize> Metrics<KIND> {
|
||||
impl TimelineGet {
|
||||
pub(crate) fn new(tenant_shard_id: &TenantShardId) -> Self {
|
||||
let per_tenant_label_values = &[
|
||||
KINDS[KIND],
|
||||
KIND,
|
||||
&tenant_shard_id.tenant_id.to_string(),
|
||||
&tenant_shard_id.shard_slug().to_string(),
|
||||
];
|
||||
Metrics {
|
||||
TimelineGet {
|
||||
count_accounted_start: {
|
||||
GlobalAndPerTenantIntCounter {
|
||||
global: COUNT_ACCOUNTED_START.with_label_values(&[KINDS[KIND]]),
|
||||
global: COUNT_ACCOUNTED_START.with_label_values(&[KIND]),
|
||||
per_tenant: COUNT_ACCOUNTED_START_PER_TENANT
|
||||
.with_label_values(per_tenant_label_values),
|
||||
}
|
||||
},
|
||||
count_accounted_finish: {
|
||||
GlobalAndPerTenantIntCounter {
|
||||
global: COUNT_ACCOUNTED_FINISH.with_label_values(&[KINDS[KIND]]),
|
||||
global: COUNT_ACCOUNTED_FINISH.with_label_values(&[KIND]),
|
||||
per_tenant: COUNT_ACCOUNTED_FINISH_PER_TENANT
|
||||
.with_label_values(per_tenant_label_values),
|
||||
}
|
||||
},
|
||||
wait_time: {
|
||||
GlobalAndPerTenantIntCounter {
|
||||
global: WAIT_USECS.with_label_values(&[KINDS[KIND]]),
|
||||
global: WAIT_USECS.with_label_values(&[KIND]),
|
||||
per_tenant: WAIT_USECS_PER_TENANT
|
||||
.with_label_values(per_tenant_label_values),
|
||||
}
|
||||
},
|
||||
count_throttled: {
|
||||
GlobalAndPerTenantIntCounter {
|
||||
global: WAIT_COUNT.with_label_values(&[KINDS[KIND]]),
|
||||
global: WAIT_COUNT.with_label_values(&[KIND]),
|
||||
per_tenant: WAIT_COUNT_PER_TENANT
|
||||
.with_label_values(per_tenant_label_values),
|
||||
}
|
||||
@@ -3468,17 +3447,15 @@ pub(crate) mod tenant_throttling {
|
||||
&WAIT_USECS_PER_TENANT,
|
||||
&WAIT_COUNT_PER_TENANT,
|
||||
] {
|
||||
for kind in KINDS {
|
||||
let _ = m.remove_label_values(&[
|
||||
kind,
|
||||
&tenant_shard_id.tenant_id.to_string(),
|
||||
&tenant_shard_id.shard_slug().to_string(),
|
||||
]);
|
||||
}
|
||||
let _ = m.remove_label_values(&[
|
||||
KIND,
|
||||
&tenant_shard_id.tenant_id.to_string(),
|
||||
&tenant_shard_id.shard_slug().to_string(),
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
impl<const KIND: usize> tenant::throttle::Metric for Metrics<KIND> {
|
||||
impl Metric for TimelineGet {
|
||||
#[inline(always)]
|
||||
fn accounting_start(&self) {
|
||||
self.count_accounted_start.inc();
|
||||
|
||||
@@ -574,41 +574,6 @@ enum BatchedFeMessage {
|
||||
},
|
||||
}
|
||||
|
||||
impl BatchedFeMessage {
|
||||
async fn throttle(&mut self, cancel: &CancellationToken) -> Result<(), QueryError> {
|
||||
let (shard, tokens, timers) = match self {
|
||||
BatchedFeMessage::Exists { shard, timer, .. }
|
||||
| BatchedFeMessage::Nblocks { shard, timer, .. }
|
||||
| BatchedFeMessage::DbSize { shard, timer, .. }
|
||||
| BatchedFeMessage::GetSlruSegment { shard, timer, .. } => {
|
||||
(
|
||||
shard,
|
||||
// 1 token is probably under-estimating because these
|
||||
// request handlers typically do several Timeline::get calls.
|
||||
1,
|
||||
itertools::Either::Left(std::iter::once(timer)),
|
||||
)
|
||||
}
|
||||
BatchedFeMessage::GetPage { shard, pages, .. } => (
|
||||
shard,
|
||||
pages.len(),
|
||||
itertools::Either::Right(pages.iter_mut().map(|(_, _, timer)| timer)),
|
||||
),
|
||||
BatchedFeMessage::RespondError { .. } => return Ok(()),
|
||||
};
|
||||
let throttled = tokio::select! {
|
||||
throttled = shard.pagestream_throttle.throttle(tokens) => { throttled }
|
||||
_ = cancel.cancelled() => {
|
||||
return Err(QueryError::Shutdown);
|
||||
}
|
||||
};
|
||||
for timer in timers {
|
||||
timer.deduct_throttle(&throttled);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl PageServerHandler {
|
||||
pub fn new(
|
||||
tenant_manager: Arc<TenantManager>,
|
||||
@@ -1192,18 +1157,13 @@ impl PageServerHandler {
|
||||
Ok(msg) => msg,
|
||||
Err(e) => break e,
|
||||
};
|
||||
let mut msg = match msg {
|
||||
let msg = match msg {
|
||||
Some(msg) => msg,
|
||||
None => {
|
||||
debug!("pagestream subprotocol end observed");
|
||||
return ((pgb_reader, timeline_handles), Ok(()));
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(cancelled) = msg.throttle(&self.cancel).await {
|
||||
break cancelled;
|
||||
}
|
||||
|
||||
let err = self
|
||||
.pagesteam_handle_batched_message(pgb_writer, msg, &cancel, ctx)
|
||||
.await;
|
||||
@@ -1361,13 +1321,12 @@ impl PageServerHandler {
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let mut batch = match batch {
|
||||
let batch = match batch {
|
||||
Ok(batch) => batch,
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
batch.throttle(&self.cancel).await?;
|
||||
self.pagesteam_handle_batched_message(pgb_writer, batch, &cancel, &ctx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
@@ -530,7 +530,6 @@ impl Timeline {
|
||||
lsn: Lsn,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<Bytes, PageReconstructError> {
|
||||
assert!(self.tenant_shard_id.is_shard_zero());
|
||||
let n_blocks = self
|
||||
.get_slru_segment_size(kind, segno, Version::Lsn(lsn), ctx)
|
||||
.await?;
|
||||
@@ -553,7 +552,6 @@ impl Timeline {
|
||||
lsn: Lsn,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<Bytes, PageReconstructError> {
|
||||
assert!(self.tenant_shard_id.is_shard_zero());
|
||||
let key = slru_block_to_key(kind, segno, blknum);
|
||||
self.get(key, lsn, ctx).await
|
||||
}
|
||||
@@ -566,7 +564,6 @@ impl Timeline {
|
||||
version: Version<'_>,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<BlockNumber, PageReconstructError> {
|
||||
assert!(self.tenant_shard_id.is_shard_zero());
|
||||
let key = slru_segment_size_to_key(kind, segno);
|
||||
let mut buf = version.get(self, key, ctx).await?;
|
||||
Ok(buf.get_u32_le())
|
||||
@@ -580,7 +577,6 @@ impl Timeline {
|
||||
version: Version<'_>,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<bool, PageReconstructError> {
|
||||
assert!(self.tenant_shard_id.is_shard_zero());
|
||||
// fetch directory listing
|
||||
let key = slru_dir_to_key(kind);
|
||||
let buf = version.get(self, key, ctx).await?;
|
||||
@@ -1051,28 +1047,26 @@ impl Timeline {
|
||||
}
|
||||
|
||||
// Iterate SLRUs next
|
||||
if self.tenant_shard_id.is_shard_zero() {
|
||||
for kind in [
|
||||
SlruKind::Clog,
|
||||
SlruKind::MultiXactMembers,
|
||||
SlruKind::MultiXactOffsets,
|
||||
] {
|
||||
let slrudir_key = slru_dir_to_key(kind);
|
||||
result.add_key(slrudir_key);
|
||||
let buf = self.get(slrudir_key, lsn, ctx).await?;
|
||||
let dir = SlruSegmentDirectory::des(&buf)?;
|
||||
let mut segments: Vec<u32> = dir.segments.iter().cloned().collect();
|
||||
segments.sort_unstable();
|
||||
for segno in segments {
|
||||
let segsize_key = slru_segment_size_to_key(kind, segno);
|
||||
let mut buf = self.get(segsize_key, lsn, ctx).await?;
|
||||
let segsize = buf.get_u32_le();
|
||||
for kind in [
|
||||
SlruKind::Clog,
|
||||
SlruKind::MultiXactMembers,
|
||||
SlruKind::MultiXactOffsets,
|
||||
] {
|
||||
let slrudir_key = slru_dir_to_key(kind);
|
||||
result.add_key(slrudir_key);
|
||||
let buf = self.get(slrudir_key, lsn, ctx).await?;
|
||||
let dir = SlruSegmentDirectory::des(&buf)?;
|
||||
let mut segments: Vec<u32> = dir.segments.iter().cloned().collect();
|
||||
segments.sort_unstable();
|
||||
for segno in segments {
|
||||
let segsize_key = slru_segment_size_to_key(kind, segno);
|
||||
let mut buf = self.get(segsize_key, lsn, ctx).await?;
|
||||
let segsize = buf.get_u32_le();
|
||||
|
||||
result.add_range(
|
||||
slru_block_to_key(kind, segno, 0)..slru_block_to_key(kind, segno, segsize),
|
||||
);
|
||||
result.add_key(segsize_key);
|
||||
}
|
||||
result.add_range(
|
||||
slru_block_to_key(kind, segno, 0)..slru_block_to_key(kind, segno, segsize),
|
||||
);
|
||||
result.add_key(segsize_key);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1474,10 +1468,6 @@ impl<'a> DatadirModification<'a> {
|
||||
blknum: BlockNumber,
|
||||
rec: NeonWalRecord,
|
||||
) -> anyhow::Result<()> {
|
||||
if !self.tline.tenant_shard_id.is_shard_zero() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.put(
|
||||
slru_block_to_key(kind, segno, blknum),
|
||||
Value::WalRecord(rec),
|
||||
@@ -1511,8 +1501,6 @@ impl<'a> DatadirModification<'a> {
|
||||
blknum: BlockNumber,
|
||||
img: Bytes,
|
||||
) -> anyhow::Result<()> {
|
||||
assert!(self.tline.tenant_shard_id.is_shard_zero());
|
||||
|
||||
let key = slru_block_to_key(kind, segno, blknum);
|
||||
if !key.is_valid_key_on_write_path() {
|
||||
anyhow::bail!(
|
||||
@@ -1554,7 +1542,6 @@ impl<'a> DatadirModification<'a> {
|
||||
segno: u32,
|
||||
blknum: BlockNumber,
|
||||
) -> anyhow::Result<()> {
|
||||
assert!(self.tline.tenant_shard_id.is_shard_zero());
|
||||
let key = slru_block_to_key(kind, segno, blknum);
|
||||
if !key.is_valid_key_on_write_path() {
|
||||
anyhow::bail!(
|
||||
@@ -1866,8 +1853,6 @@ impl<'a> DatadirModification<'a> {
|
||||
nblocks: BlockNumber,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<()> {
|
||||
assert!(self.tline.tenant_shard_id.is_shard_zero());
|
||||
|
||||
// Add it to the directory entry
|
||||
let dir_key = slru_dir_to_key(kind);
|
||||
let buf = self.get(dir_key, ctx).await?;
|
||||
@@ -1900,8 +1885,6 @@ impl<'a> DatadirModification<'a> {
|
||||
segno: u32,
|
||||
nblocks: BlockNumber,
|
||||
) -> anyhow::Result<()> {
|
||||
assert!(self.tline.tenant_shard_id.is_shard_zero());
|
||||
|
||||
// Put size
|
||||
let size_key = slru_segment_size_to_key(kind, segno);
|
||||
let buf = nblocks.to_le_bytes();
|
||||
|
||||
@@ -357,8 +357,8 @@ pub struct Tenant {
|
||||
|
||||
/// Throttle applied at the top of [`Timeline::get`].
|
||||
/// All [`Tenant::timelines`] of a given [`Tenant`] instance share the same [`throttle::Throttle`] instance.
|
||||
pub(crate) pagestream_throttle:
|
||||
Arc<throttle::Throttle<crate::metrics::tenant_throttling::Pagestream>>,
|
||||
pub(crate) timeline_get_throttle:
|
||||
Arc<throttle::Throttle<crate::metrics::tenant_throttling::TimelineGet>>,
|
||||
|
||||
/// An ongoing timeline detach concurrency limiter.
|
||||
///
|
||||
@@ -1678,7 +1678,7 @@ impl Tenant {
|
||||
remote_metadata,
|
||||
TimelineResources {
|
||||
remote_client,
|
||||
pagestream_throttle: self.pagestream_throttle.clone(),
|
||||
timeline_get_throttle: self.timeline_get_throttle.clone(),
|
||||
l0_flush_global_state: self.l0_flush_global_state.clone(),
|
||||
},
|
||||
LoadTimelineCause::Attach,
|
||||
@@ -3422,7 +3422,7 @@ impl Tenant {
|
||||
r.map_err(
|
||||
|_e: tokio::sync::watch::error::RecvError|
|
||||
// Tenant existed but was dropped: report it as non-existent
|
||||
GetActiveTenantError::NotFound(GetTenantError::ShardNotFound(self.tenant_shard_id))
|
||||
GetActiveTenantError::NotFound(GetTenantError::NotFound(self.tenant_shard_id.tenant_id))
|
||||
)?
|
||||
}
|
||||
Err(TimeoutCancellableError::Cancelled) => {
|
||||
@@ -3835,7 +3835,7 @@ impl Tenant {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pagestream_throttle_config(
|
||||
fn get_timeline_get_throttle_config(
|
||||
psconf: &'static PageServerConf,
|
||||
overrides: &TenantConfOpt,
|
||||
) -> throttle::Config {
|
||||
@@ -3846,8 +3846,8 @@ impl Tenant {
|
||||
}
|
||||
|
||||
pub(crate) fn tenant_conf_updated(&self, new_conf: &TenantConfOpt) {
|
||||
let conf = Self::get_pagestream_throttle_config(self.conf, new_conf);
|
||||
self.pagestream_throttle.reconfigure(conf)
|
||||
let conf = Self::get_timeline_get_throttle_config(self.conf, new_conf);
|
||||
self.timeline_get_throttle.reconfigure(conf)
|
||||
}
|
||||
|
||||
/// Helper function to create a new Timeline struct.
|
||||
@@ -4009,9 +4009,9 @@ impl Tenant {
|
||||
attach_wal_lag_cooldown: Arc::new(std::sync::OnceLock::new()),
|
||||
cancel: CancellationToken::default(),
|
||||
gate: Gate::default(),
|
||||
pagestream_throttle: Arc::new(throttle::Throttle::new(
|
||||
Tenant::get_pagestream_throttle_config(conf, &attached_conf.tenant_conf),
|
||||
crate::metrics::tenant_throttling::Metrics::new(&tenant_shard_id),
|
||||
timeline_get_throttle: Arc::new(throttle::Throttle::new(
|
||||
Tenant::get_timeline_get_throttle_config(conf, &attached_conf.tenant_conf),
|
||||
crate::metrics::tenant_throttling::TimelineGet::new(&tenant_shard_id),
|
||||
)),
|
||||
tenant_conf: Arc::new(ArcSwap::from_pointee(attached_conf)),
|
||||
ongoing_timeline_detach: std::sync::Mutex::default(),
|
||||
@@ -4909,7 +4909,7 @@ impl Tenant {
|
||||
fn build_timeline_resources(&self, timeline_id: TimelineId) -> TimelineResources {
|
||||
TimelineResources {
|
||||
remote_client: self.build_timeline_remote_client(timeline_id),
|
||||
pagestream_throttle: self.pagestream_throttle.clone(),
|
||||
timeline_get_throttle: self.timeline_get_throttle.clone(),
|
||||
l0_flush_global_state: self.l0_flush_global_state.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,10 @@ use crate::page_cache;
|
||||
use crate::tenant::storage_layer::inmemory_layer::vectored_dio_read::File;
|
||||
use crate::virtual_file::owned_buffers_io::io_buf_aligned::IoBufAlignedMut;
|
||||
use crate::virtual_file::owned_buffers_io::slice::SliceMutExt;
|
||||
use crate::virtual_file::owned_buffers_io::util::size_tracking_writer;
|
||||
use crate::virtual_file::owned_buffers_io::write::Buffer;
|
||||
use crate::virtual_file::{self, owned_buffers_io, IoBufferMut, VirtualFile};
|
||||
use bytes::BytesMut;
|
||||
use camino::Utf8PathBuf;
|
||||
use num_traits::Num;
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
@@ -18,7 +20,6 @@ use tracing::error;
|
||||
|
||||
use std::io;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::Arc;
|
||||
use utils::id::TimelineId;
|
||||
|
||||
pub struct EphemeralFile {
|
||||
@@ -26,7 +27,10 @@ pub struct EphemeralFile {
|
||||
_timeline_id: TimelineId,
|
||||
page_cache_file_id: page_cache::FileId,
|
||||
bytes_written: u64,
|
||||
buffered_writer: owned_buffers_io::write::BufferedWriter<IoBufferMut, VirtualFile>,
|
||||
buffered_writer: owned_buffers_io::write::BufferedWriter<
|
||||
BytesMut,
|
||||
size_tracking_writer::Writer<VirtualFile>,
|
||||
>,
|
||||
/// Gate guard is held on as long as we need to do operations in the path (delete on drop)
|
||||
_gate_guard: utils::sync::gate::GateGuard,
|
||||
}
|
||||
@@ -38,9 +42,9 @@ impl EphemeralFile {
|
||||
conf: &PageServerConf,
|
||||
tenant_shard_id: TenantShardId,
|
||||
timeline_id: TimelineId,
|
||||
gate: &utils::sync::gate::Gate,
|
||||
gate_guard: utils::sync::gate::GateGuard,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<EphemeralFile> {
|
||||
) -> Result<EphemeralFile, io::Error> {
|
||||
static NEXT_FILENAME: AtomicU64 = AtomicU64::new(1);
|
||||
let filename_disambiguator =
|
||||
NEXT_FILENAME.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
@@ -51,17 +55,15 @@ impl EphemeralFile {
|
||||
"ephemeral-{filename_disambiguator}"
|
||||
)));
|
||||
|
||||
let file = Arc::new(
|
||||
VirtualFile::open_with_options_v2(
|
||||
&filename,
|
||||
virtual_file::OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.create(true),
|
||||
ctx,
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
let file = VirtualFile::open_with_options(
|
||||
&filename,
|
||||
virtual_file::OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.create(true),
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let page_cache_file_id = page_cache::next_file_id(); // XXX get rid, we're not page-caching anymore
|
||||
|
||||
@@ -71,12 +73,10 @@ impl EphemeralFile {
|
||||
page_cache_file_id,
|
||||
bytes_written: 0,
|
||||
buffered_writer: owned_buffers_io::write::BufferedWriter::new(
|
||||
file,
|
||||
|| IoBufferMut::with_capacity(TAIL_SZ),
|
||||
gate.enter()?,
|
||||
ctx,
|
||||
size_tracking_writer::Writer::new(file),
|
||||
BytesMut::with_capacity(TAIL_SZ),
|
||||
),
|
||||
_gate_guard: gate.enter()?,
|
||||
_gate_guard: gate_guard,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -85,7 +85,7 @@ impl Drop for EphemeralFile {
|
||||
fn drop(&mut self) {
|
||||
// unlink the file
|
||||
// we are clear to do this, because we have entered a gate
|
||||
let path = self.buffered_writer.as_inner().path();
|
||||
let path = self.buffered_writer.as_inner().as_inner().path();
|
||||
let res = std::fs::remove_file(path);
|
||||
if let Err(e) = res {
|
||||
if e.kind() != std::io::ErrorKind::NotFound {
|
||||
@@ -132,18 +132,6 @@ impl EphemeralFile {
|
||||
srcbuf: &[u8],
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<u64> {
|
||||
let (pos, control) = self.write_raw_controlled(srcbuf, ctx).await?;
|
||||
if let Some(control) = control {
|
||||
control.release().await;
|
||||
}
|
||||
Ok(pos)
|
||||
}
|
||||
|
||||
async fn write_raw_controlled(
|
||||
&mut self,
|
||||
srcbuf: &[u8],
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<(u64, Option<owned_buffers_io::write::FlushControl>)> {
|
||||
let pos = self.bytes_written;
|
||||
|
||||
let new_bytes_written = pos.checked_add(srcbuf.len().into_u64()).ok_or_else(|| {
|
||||
@@ -157,9 +145,9 @@ impl EphemeralFile {
|
||||
})?;
|
||||
|
||||
// Write the payload
|
||||
let (nwritten, control) = self
|
||||
let nwritten = self
|
||||
.buffered_writer
|
||||
.write_buffered_borrowed_controlled(srcbuf, ctx)
|
||||
.write_buffered_borrowed(srcbuf, ctx)
|
||||
.await?;
|
||||
assert_eq!(
|
||||
nwritten,
|
||||
@@ -169,7 +157,7 @@ impl EphemeralFile {
|
||||
|
||||
self.bytes_written = new_bytes_written;
|
||||
|
||||
Ok((pos, control))
|
||||
Ok(pos)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,12 +168,11 @@ impl super::storage_layer::inmemory_layer::vectored_dio_read::File for Ephemeral
|
||||
dst: tokio_epoll_uring::Slice<B>,
|
||||
ctx: &'a RequestContext,
|
||||
) -> std::io::Result<(tokio_epoll_uring::Slice<B>, usize)> {
|
||||
let submitted_offset = self.buffered_writer.bytes_submitted();
|
||||
let file_size_tracking_writer = self.buffered_writer.as_inner();
|
||||
let flushed_offset = file_size_tracking_writer.bytes_written();
|
||||
|
||||
let mutable = self.buffered_writer.inspect_mutable();
|
||||
let mutable = &mutable[0..mutable.pending()];
|
||||
|
||||
let maybe_flushed = self.buffered_writer.inspect_maybe_flushed();
|
||||
let buffer = self.buffered_writer.inspect_buffer();
|
||||
let buffered = &buffer[0..buffer.pending()];
|
||||
|
||||
let dst_cap = dst.bytes_total().into_u64();
|
||||
let end = {
|
||||
@@ -210,42 +197,11 @@ impl super::storage_layer::inmemory_layer::vectored_dio_read::File for Ephemeral
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let (written_range, maybe_flushed_range) = {
|
||||
if maybe_flushed.is_some() {
|
||||
// [ written ][ maybe_flushed ][ mutable ]
|
||||
// <- TAIL_SZ -><- TAIL_SZ ->
|
||||
// ^
|
||||
// `submitted_offset`
|
||||
// <++++++ on disk +++++++????????????????>
|
||||
(
|
||||
Range(
|
||||
start,
|
||||
std::cmp::min(end, submitted_offset.saturating_sub(TAIL_SZ as u64)),
|
||||
),
|
||||
Range(
|
||||
std::cmp::max(start, submitted_offset.saturating_sub(TAIL_SZ as u64)),
|
||||
std::cmp::min(end, submitted_offset),
|
||||
),
|
||||
)
|
||||
} else {
|
||||
// [ written ][ mutable ]
|
||||
// <- TAIL_SZ ->
|
||||
// ^
|
||||
// `submitted_offset`
|
||||
// <++++++ on disk +++++++++++++++++++++++>
|
||||
(
|
||||
Range(start, std::cmp::min(end, submitted_offset)),
|
||||
// zero len
|
||||
Range(submitted_offset, u64::MIN),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let mutable_range = Range(std::cmp::max(start, submitted_offset), end);
|
||||
let written_range = Range(start, std::cmp::min(end, flushed_offset));
|
||||
let buffered_range = Range(std::cmp::max(start, flushed_offset), end);
|
||||
|
||||
let dst = if written_range.len() > 0 {
|
||||
let file: &VirtualFile = self.buffered_writer.as_inner();
|
||||
let file: &VirtualFile = file_size_tracking_writer.as_inner();
|
||||
let bounds = dst.bounds();
|
||||
let slice = file
|
||||
.read_exact_at(dst.slice(0..written_range.len().into_usize()), start, ctx)
|
||||
@@ -255,21 +211,19 @@ impl super::storage_layer::inmemory_layer::vectored_dio_read::File for Ephemeral
|
||||
dst
|
||||
};
|
||||
|
||||
let dst = if maybe_flushed_range.len() > 0 {
|
||||
let offset_in_buffer = maybe_flushed_range
|
||||
let dst = if buffered_range.len() > 0 {
|
||||
let offset_in_buffer = buffered_range
|
||||
.0
|
||||
.checked_sub(submitted_offset.saturating_sub(TAIL_SZ as u64))
|
||||
.checked_sub(flushed_offset)
|
||||
.unwrap()
|
||||
.into_usize();
|
||||
// Checked previously the buffer is Some.
|
||||
let maybe_flushed = maybe_flushed.unwrap();
|
||||
let to_copy = &maybe_flushed
|
||||
[offset_in_buffer..(offset_in_buffer + maybe_flushed_range.len().into_usize())];
|
||||
let to_copy =
|
||||
&buffered[offset_in_buffer..(offset_in_buffer + buffered_range.len().into_usize())];
|
||||
let bounds = dst.bounds();
|
||||
let mut view = dst.slice({
|
||||
let start = written_range.len().into_usize();
|
||||
let end = start
|
||||
.checked_add(maybe_flushed_range.len().into_usize())
|
||||
.checked_add(buffered_range.len().into_usize())
|
||||
.unwrap();
|
||||
start..end
|
||||
});
|
||||
@@ -280,28 +234,6 @@ impl super::storage_layer::inmemory_layer::vectored_dio_read::File for Ephemeral
|
||||
dst
|
||||
};
|
||||
|
||||
let dst = if mutable_range.len() > 0 {
|
||||
let offset_in_buffer = mutable_range
|
||||
.0
|
||||
.checked_sub(submitted_offset)
|
||||
.unwrap()
|
||||
.into_usize();
|
||||
let to_copy =
|
||||
&mutable[offset_in_buffer..(offset_in_buffer + mutable_range.len().into_usize())];
|
||||
let bounds = dst.bounds();
|
||||
let mut view = dst.slice({
|
||||
let start =
|
||||
written_range.len().into_usize() + maybe_flushed_range.len().into_usize();
|
||||
let end = start.checked_add(mutable_range.len().into_usize()).unwrap();
|
||||
start..end
|
||||
});
|
||||
view.as_mut_rust_slice_full_zeroed()
|
||||
.copy_from_slice(to_copy);
|
||||
Slice::from_buf_bounds(Slice::into_inner(view), bounds)
|
||||
} else {
|
||||
dst
|
||||
};
|
||||
|
||||
// TODO: in debug mode, randomize the remaining bytes in `dst` to catch bugs
|
||||
|
||||
Ok((dst, (end - start).into_usize()))
|
||||
@@ -363,7 +295,7 @@ mod tests {
|
||||
|
||||
let gate = utils::sync::gate::Gate::default();
|
||||
|
||||
let file = EphemeralFile::create(conf, tenant_id, timeline_id, &gate, &ctx)
|
||||
let file = EphemeralFile::create(conf, tenant_id, timeline_id, gate.enter().unwrap(), &ctx)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -394,15 +326,14 @@ mod tests {
|
||||
|
||||
let gate = utils::sync::gate::Gate::default();
|
||||
|
||||
let mut file = EphemeralFile::create(conf, tenant_id, timeline_id, &gate, &ctx)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut file =
|
||||
EphemeralFile::create(conf, tenant_id, timeline_id, gate.enter().unwrap(), &ctx)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mutable = file.buffered_writer.inspect_mutable();
|
||||
let cap = mutable.capacity();
|
||||
let align = mutable.align();
|
||||
let cap = file.buffered_writer.inspect_buffer().capacity();
|
||||
|
||||
let write_nbytes = cap * 2 + cap / 2;
|
||||
let write_nbytes = cap + cap / 2;
|
||||
|
||||
let content: Vec<u8> = rand::thread_rng()
|
||||
.sample_iter(rand::distributions::Standard)
|
||||
@@ -410,39 +341,30 @@ mod tests {
|
||||
.collect();
|
||||
|
||||
let mut value_offsets = Vec::new();
|
||||
for range in (0..write_nbytes)
|
||||
.step_by(align)
|
||||
.map(|start| start..(start + align).min(write_nbytes))
|
||||
{
|
||||
let off = file.write_raw(&content[range], &ctx).await.unwrap();
|
||||
for i in 0..write_nbytes {
|
||||
let off = file.write_raw(&content[i..i + 1], &ctx).await.unwrap();
|
||||
value_offsets.push(off);
|
||||
}
|
||||
|
||||
assert_eq!(file.len() as usize, write_nbytes);
|
||||
for (i, range) in (0..write_nbytes)
|
||||
.step_by(align)
|
||||
.map(|start| start..(start + align).min(write_nbytes))
|
||||
.enumerate()
|
||||
{
|
||||
assert_eq!(value_offsets[i], range.start.into_u64());
|
||||
let buf = IoBufferMut::with_capacity(range.len());
|
||||
assert!(file.len() as usize == write_nbytes);
|
||||
for i in 0..write_nbytes {
|
||||
assert_eq!(value_offsets[i], i.into_u64());
|
||||
let buf = IoBufferMut::with_capacity(1);
|
||||
let (buf_slice, nread) = file
|
||||
.read_exact_at_eof_ok(range.start.into_u64(), buf.slice_full(), &ctx)
|
||||
.read_exact_at_eof_ok(i.into_u64(), buf.slice_full(), &ctx)
|
||||
.await
|
||||
.unwrap();
|
||||
let buf = buf_slice.into_inner();
|
||||
assert_eq!(nread, range.len());
|
||||
assert_eq!(&buf, &content[range]);
|
||||
assert_eq!(nread, 1);
|
||||
assert_eq!(&buf, &content[i..i + 1]);
|
||||
}
|
||||
|
||||
let file_contents = std::fs::read(file.buffered_writer.as_inner().path()).unwrap();
|
||||
assert!(file_contents == content[0..cap * 2]);
|
||||
let file_contents =
|
||||
std::fs::read(file.buffered_writer.as_inner().as_inner().path()).unwrap();
|
||||
assert_eq!(file_contents, &content[0..cap]);
|
||||
|
||||
let maybe_flushed_buffer_contents = file.buffered_writer.inspect_maybe_flushed().unwrap();
|
||||
assert_eq!(&maybe_flushed_buffer_contents[..], &content[cap..cap * 2]);
|
||||
|
||||
let mutable_buffer_contents = file.buffered_writer.inspect_mutable();
|
||||
assert_eq!(mutable_buffer_contents, &content[cap * 2..write_nbytes]);
|
||||
let buffer_contents = file.buffered_writer.inspect_buffer();
|
||||
assert_eq!(buffer_contents, &content[cap..write_nbytes]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -451,16 +373,16 @@ mod tests {
|
||||
|
||||
let gate = utils::sync::gate::Gate::default();
|
||||
|
||||
let mut file = EphemeralFile::create(conf, tenant_id, timeline_id, &gate, &ctx)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut file =
|
||||
EphemeralFile::create(conf, tenant_id, timeline_id, gate.enter().unwrap(), &ctx)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// mutable buffer and maybe_flushed buffer each has `cap` bytes.
|
||||
let cap = file.buffered_writer.inspect_mutable().capacity();
|
||||
let cap = file.buffered_writer.inspect_buffer().capacity();
|
||||
|
||||
let content: Vec<u8> = rand::thread_rng()
|
||||
.sample_iter(rand::distributions::Standard)
|
||||
.take(cap * 2 + cap / 2)
|
||||
.take(cap + cap / 2)
|
||||
.collect();
|
||||
|
||||
file.write_raw(&content, &ctx).await.unwrap();
|
||||
@@ -468,21 +390,23 @@ mod tests {
|
||||
// assert the state is as this test expects it to be
|
||||
assert_eq!(
|
||||
&file.load_to_io_buf(&ctx).await.unwrap(),
|
||||
&content[0..cap * 2 + cap / 2]
|
||||
&content[0..cap + cap / 2]
|
||||
);
|
||||
let md = file.buffered_writer.as_inner().path().metadata().unwrap();
|
||||
let md = file
|
||||
.buffered_writer
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.path()
|
||||
.metadata()
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
md.len(),
|
||||
2 * cap.into_u64(),
|
||||
"buffered writer requires one write to be flushed if we write 2.5x buffer capacity"
|
||||
cap.into_u64(),
|
||||
"buffered writer does one write if we write 1.5x buffer capacity"
|
||||
);
|
||||
assert_eq!(
|
||||
&file.buffered_writer.inspect_maybe_flushed().unwrap()[0..cap],
|
||||
&content[cap..cap * 2]
|
||||
);
|
||||
assert_eq!(
|
||||
&file.buffered_writer.inspect_mutable()[0..cap / 2],
|
||||
&content[cap * 2..cap * 2 + cap / 2]
|
||||
&file.buffered_writer.inspect_buffer()[0..cap / 2],
|
||||
&content[cap..cap + cap / 2]
|
||||
);
|
||||
}
|
||||
|
||||
@@ -498,19 +422,19 @@ mod tests {
|
||||
|
||||
let gate = utils::sync::gate::Gate::default();
|
||||
|
||||
let mut file = EphemeralFile::create(conf, tenant_id, timeline_id, &gate, &ctx)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut file =
|
||||
EphemeralFile::create(conf, tenant_id, timeline_id, gate.enter().unwrap(), &ctx)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cap = file.buffered_writer.inspect_buffer().capacity();
|
||||
|
||||
let mutable = file.buffered_writer.inspect_mutable();
|
||||
let cap = mutable.capacity();
|
||||
let align = mutable.align();
|
||||
let content: Vec<u8> = rand::thread_rng()
|
||||
.sample_iter(rand::distributions::Standard)
|
||||
.take(cap * 2 + cap / 2)
|
||||
.take(cap + cap / 2)
|
||||
.collect();
|
||||
|
||||
let (_, control) = file.write_raw_controlled(&content, &ctx).await.unwrap();
|
||||
file.write_raw(&content, &ctx).await.unwrap();
|
||||
|
||||
let test_read = |start: usize, len: usize| {
|
||||
let file = &file;
|
||||
@@ -530,38 +454,16 @@ mod tests {
|
||||
}
|
||||
};
|
||||
|
||||
let test_read_all_offset_combinations = || {
|
||||
async move {
|
||||
test_read(align, align).await;
|
||||
// border onto edge of file
|
||||
test_read(cap - align, align).await;
|
||||
// read across file and buffer
|
||||
test_read(cap - align, 2 * align).await;
|
||||
// stay from start of maybe flushed buffer
|
||||
test_read(cap, align).await;
|
||||
// completely within maybe flushed buffer
|
||||
test_read(cap + align, align).await;
|
||||
// border onto edge of maybe flushed buffer.
|
||||
test_read(cap * 2 - align, align).await;
|
||||
// read across maybe flushed and mutable buffer
|
||||
test_read(cap * 2 - align, 2 * align).await;
|
||||
// read across three segments
|
||||
test_read(cap - align, cap + 2 * align).await;
|
||||
// completely within mutable buffer
|
||||
test_read(cap * 2 + align, align).await;
|
||||
}
|
||||
};
|
||||
|
||||
// completely within the file range
|
||||
assert!(align < cap, "test assumption");
|
||||
assert!(cap % align == 0);
|
||||
|
||||
// test reads at different flush stages.
|
||||
let not_started = control.unwrap().into_not_started();
|
||||
test_read_all_offset_combinations().await;
|
||||
let in_progress = not_started.ready_to_flush();
|
||||
test_read_all_offset_combinations().await;
|
||||
in_progress.wait_until_flush_is_done().await;
|
||||
test_read_all_offset_combinations().await;
|
||||
assert!(20 < cap, "test assumption");
|
||||
test_read(10, 10).await;
|
||||
// border onto edge of file
|
||||
test_read(cap - 10, 10).await;
|
||||
// read across file and buffer
|
||||
test_read(cap - 10, 20).await;
|
||||
// stay from start of buffer
|
||||
test_read(cap, 10).await;
|
||||
// completely within buffer
|
||||
test_read(cap + 10, 10).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,7 +347,7 @@ async fn init_load_generations(
|
||||
);
|
||||
emergency_generations(tenant_confs)
|
||||
} else if let Some(client) = ControllerUpcallClient::new(conf, cancel) {
|
||||
info!("Calling {} API to re-attach tenants", client.base_url());
|
||||
info!("Calling control plane API to re-attach tenants");
|
||||
// If we are configured to use the control plane API, then it is the source of truth for what tenants to load.
|
||||
match client.re_attach(conf).await {
|
||||
Ok(tenants) => tenants
|
||||
@@ -894,7 +894,7 @@ impl TenantManager {
|
||||
Some(TenantSlot::Attached(tenant)) => Ok(Arc::clone(tenant)),
|
||||
Some(TenantSlot::InProgress(_)) => Err(GetTenantError::NotActive(tenant_shard_id)),
|
||||
None | Some(TenantSlot::Secondary(_)) => {
|
||||
Err(GetTenantError::ShardNotFound(tenant_shard_id))
|
||||
Err(GetTenantError::NotFound(tenant_shard_id.tenant_id))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2258,9 +2258,6 @@ pub(crate) enum GetTenantError {
|
||||
#[error("Tenant {0} not found")]
|
||||
NotFound(TenantId),
|
||||
|
||||
#[error("Tenant {0} not found")]
|
||||
ShardNotFound(TenantShardId),
|
||||
|
||||
#[error("Tenant {0} is not active")]
|
||||
NotActive(TenantShardId),
|
||||
|
||||
|
||||
@@ -681,7 +681,6 @@ impl RemoteTimelineClient {
|
||||
layer_file_name: &LayerName,
|
||||
layer_metadata: &LayerFileMetadata,
|
||||
local_path: &Utf8Path,
|
||||
gate: &utils::sync::gate::Gate,
|
||||
cancel: &CancellationToken,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<u64, DownloadError> {
|
||||
@@ -701,7 +700,6 @@ impl RemoteTimelineClient {
|
||||
layer_file_name,
|
||||
layer_metadata,
|
||||
local_path,
|
||||
gate,
|
||||
cancel,
|
||||
ctx,
|
||||
)
|
||||
@@ -2566,9 +2564,9 @@ pub fn parse_remote_index_path(path: RemotePath) -> Option<Generation> {
|
||||
}
|
||||
|
||||
/// Given the key of a tenant manifest, parse out the generation number
|
||||
pub fn parse_remote_tenant_manifest_path(path: RemotePath) -> Option<Generation> {
|
||||
pub(crate) fn parse_remote_tenant_manifest_path(path: RemotePath) -> Option<Generation> {
|
||||
static RE: OnceLock<Regex> = OnceLock::new();
|
||||
let re = RE.get_or_init(|| Regex::new(r".*tenant-manifest-([0-9a-f]{8}).json").unwrap());
|
||||
let re = RE.get_or_init(|| Regex::new(r".+tenant-manifest-([0-9a-f]{8}).json").unwrap());
|
||||
re.captures(path.get_path().as_str())
|
||||
.and_then(|c| c.get(1))
|
||||
.and_then(|m| Generation::parse_suffix(m.as_str()))
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
use std::collections::HashSet;
|
||||
use std::future::Future;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
@@ -27,7 +26,9 @@ use crate::span::{
|
||||
use crate::tenant::remote_timeline_client::{remote_layer_path, remote_timelines_path};
|
||||
use crate::tenant::storage_layer::LayerName;
|
||||
use crate::tenant::Generation;
|
||||
use crate::virtual_file::{on_fatal_io_error, IoBufferMut, MaybeFatalIo, VirtualFile};
|
||||
#[cfg_attr(target_os = "macos", allow(unused_imports))]
|
||||
use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;
|
||||
use crate::virtual_file::{on_fatal_io_error, MaybeFatalIo, VirtualFile};
|
||||
use crate::TEMP_FILE_SUFFIX;
|
||||
use remote_storage::{
|
||||
DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath,
|
||||
@@ -59,7 +60,6 @@ pub async fn download_layer_file<'a>(
|
||||
layer_file_name: &'a LayerName,
|
||||
layer_metadata: &'a LayerFileMetadata,
|
||||
local_path: &Utf8Path,
|
||||
gate: &utils::sync::gate::Gate,
|
||||
cancel: &CancellationToken,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<u64, DownloadError> {
|
||||
@@ -88,9 +88,7 @@ pub async fn download_layer_file<'a>(
|
||||
let temp_file_path = path_with_suffix_extension(local_path, TEMP_DOWNLOAD_EXTENSION);
|
||||
|
||||
let bytes_amount = download_retry(
|
||||
|| async {
|
||||
download_object(storage, &remote_path, &temp_file_path, gate, cancel, ctx).await
|
||||
},
|
||||
|| async { download_object(storage, &remote_path, &temp_file_path, cancel, ctx).await },
|
||||
&format!("download {remote_path:?}"),
|
||||
cancel,
|
||||
)
|
||||
@@ -150,7 +148,6 @@ async fn download_object<'a>(
|
||||
storage: &'a GenericRemoteStorage,
|
||||
src_path: &RemotePath,
|
||||
dst_path: &Utf8PathBuf,
|
||||
gate: &utils::sync::gate::Gate,
|
||||
cancel: &CancellationToken,
|
||||
#[cfg_attr(target_os = "macos", allow(unused_variables))] ctx: &RequestContext,
|
||||
) -> Result<u64, DownloadError> {
|
||||
@@ -208,16 +205,13 @@ async fn download_object<'a>(
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
crate::virtual_file::io_engine::IoEngine::TokioEpollUring => {
|
||||
use crate::virtual_file::owned_buffers_io;
|
||||
use crate::virtual_file::owned_buffers_io::{self, util::size_tracking_writer};
|
||||
use bytes::BytesMut;
|
||||
async {
|
||||
let destination_file = Arc::new(
|
||||
VirtualFile::create(dst_path, ctx)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("create a destination file for layer '{dst_path}'")
|
||||
})
|
||||
.map_err(DownloadError::Other)?,
|
||||
);
|
||||
let destination_file = VirtualFile::create(dst_path, ctx)
|
||||
.await
|
||||
.with_context(|| format!("create a destination file for layer '{dst_path}'"))
|
||||
.map_err(DownloadError::Other)?;
|
||||
|
||||
let mut download = storage
|
||||
.download(src_path, &DownloadOpts::default(), cancel)
|
||||
@@ -225,16 +219,14 @@ async fn download_object<'a>(
|
||||
|
||||
pausable_failpoint!("before-downloading-layer-stream-pausable");
|
||||
|
||||
let mut buffered = owned_buffers_io::write::BufferedWriter::<IoBufferMut, _>::new(
|
||||
destination_file,
|
||||
|| IoBufferMut::with_capacity(super::BUFFER_SIZE),
|
||||
gate.enter().map_err(|_| DownloadError::Cancelled)?,
|
||||
ctx,
|
||||
);
|
||||
|
||||
// TODO: use vectored write (writev) once supported by tokio-epoll-uring.
|
||||
// There's chunks_vectored() on the stream.
|
||||
let (bytes_amount, destination_file) = async {
|
||||
let size_tracking = size_tracking_writer::Writer::new(destination_file);
|
||||
let mut buffered = owned_buffers_io::write::BufferedWriter::<BytesMut, _>::new(
|
||||
size_tracking,
|
||||
BytesMut::with_capacity(super::BUFFER_SIZE),
|
||||
);
|
||||
while let Some(res) =
|
||||
futures::StreamExt::next(&mut download.download_stream).await
|
||||
{
|
||||
@@ -242,10 +234,10 @@ async fn download_object<'a>(
|
||||
Ok(chunk) => chunk,
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
buffered.write_buffered_borrowed(&chunk, ctx).await?;
|
||||
buffered.write_buffered(chunk.slice_len(), ctx).await?;
|
||||
}
|
||||
let inner = buffered.flush_and_into_inner(ctx).await?;
|
||||
Ok(inner)
|
||||
let size_tracking = buffered.flush_and_into_inner(ctx).await?;
|
||||
Ok(size_tracking.into_inner())
|
||||
}
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ impl TenantManifest {
|
||||
offloaded_timelines: vec![],
|
||||
}
|
||||
}
|
||||
pub fn from_json_bytes(bytes: &[u8]) -> Result<Self, serde_json::Error> {
|
||||
pub(crate) fn from_json_bytes(bytes: &[u8]) -> Result<Self, serde_json::Error> {
|
||||
serde_json::from_slice::<Self>(bytes)
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ use super::{
|
||||
mgr::TenantManager,
|
||||
span::debug_assert_current_span_has_tenant_id,
|
||||
storage_layer::LayerName,
|
||||
GetTenantError,
|
||||
};
|
||||
|
||||
use crate::metrics::SECONDARY_RESIDENT_PHYSICAL_SIZE;
|
||||
@@ -67,21 +66,7 @@ struct CommandRequest<T> {
|
||||
}
|
||||
|
||||
struct CommandResponse {
|
||||
result: Result<(), SecondaryTenantError>,
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub(crate) enum SecondaryTenantError {
|
||||
#[error("{0}")]
|
||||
GetTenant(GetTenantError),
|
||||
#[error("shutting down")]
|
||||
ShuttingDown,
|
||||
}
|
||||
|
||||
impl From<GetTenantError> for SecondaryTenantError {
|
||||
fn from(gte: GetTenantError) -> Self {
|
||||
Self::GetTenant(gte)
|
||||
}
|
||||
result: anyhow::Result<()>,
|
||||
}
|
||||
|
||||
// Whereas [`Tenant`] represents an attached tenant, this type represents the work
|
||||
@@ -300,7 +285,7 @@ impl SecondaryController {
|
||||
&self,
|
||||
queue: &tokio::sync::mpsc::Sender<CommandRequest<T>>,
|
||||
payload: T,
|
||||
) -> Result<(), SecondaryTenantError> {
|
||||
) -> anyhow::Result<()> {
|
||||
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
|
||||
|
||||
queue
|
||||
@@ -309,26 +294,20 @@ impl SecondaryController {
|
||||
response_tx,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| SecondaryTenantError::ShuttingDown)?;
|
||||
.map_err(|_| anyhow::anyhow!("Receiver shut down"))?;
|
||||
|
||||
let response = response_rx
|
||||
.await
|
||||
.map_err(|_| SecondaryTenantError::ShuttingDown)?;
|
||||
.map_err(|_| anyhow::anyhow!("Request dropped"))?;
|
||||
|
||||
response.result
|
||||
}
|
||||
|
||||
pub(crate) async fn upload_tenant(
|
||||
&self,
|
||||
tenant_shard_id: TenantShardId,
|
||||
) -> Result<(), SecondaryTenantError> {
|
||||
pub async fn upload_tenant(&self, tenant_shard_id: TenantShardId) -> anyhow::Result<()> {
|
||||
self.dispatch(&self.upload_req_tx, UploadCommand::Upload(tenant_shard_id))
|
||||
.await
|
||||
}
|
||||
pub(crate) async fn download_tenant(
|
||||
&self,
|
||||
tenant_shard_id: TenantShardId,
|
||||
) -> Result<(), SecondaryTenantError> {
|
||||
pub async fn download_tenant(&self, tenant_shard_id: TenantShardId) -> anyhow::Result<()> {
|
||||
self.dispatch(
|
||||
&self.download_req_tx,
|
||||
DownloadCommand::Download(tenant_shard_id),
|
||||
|
||||
@@ -35,7 +35,7 @@ use super::{
|
||||
self, period_jitter, period_warmup, Completion, JobGenerator, SchedulingResult,
|
||||
TenantBackgroundJobs,
|
||||
},
|
||||
GetTenantError, SecondaryTenant, SecondaryTenantError,
|
||||
SecondaryTenant,
|
||||
};
|
||||
|
||||
use crate::tenant::{
|
||||
@@ -470,16 +470,15 @@ impl JobGenerator<PendingDownload, RunningDownload, CompleteDownload, DownloadCo
|
||||
result
|
||||
}
|
||||
|
||||
fn on_command(
|
||||
&mut self,
|
||||
command: DownloadCommand,
|
||||
) -> Result<PendingDownload, SecondaryTenantError> {
|
||||
fn on_command(&mut self, command: DownloadCommand) -> anyhow::Result<PendingDownload> {
|
||||
let tenant_shard_id = command.get_tenant_shard_id();
|
||||
|
||||
let tenant = self
|
||||
.tenant_manager
|
||||
.get_secondary_tenant_shard(*tenant_shard_id)
|
||||
.ok_or(GetTenantError::ShardNotFound(*tenant_shard_id))?;
|
||||
.get_secondary_tenant_shard(*tenant_shard_id);
|
||||
let Some(tenant) = tenant else {
|
||||
return Err(anyhow::anyhow!("Not found or not in Secondary mode"));
|
||||
};
|
||||
|
||||
Ok(PendingDownload {
|
||||
target_time: None,
|
||||
@@ -1183,7 +1182,6 @@ impl<'a> TenantDownloader<'a> {
|
||||
&layer.name,
|
||||
&layer.metadata,
|
||||
&local_path,
|
||||
&self.secondary_state.gate,
|
||||
&self.secondary_state.cancel,
|
||||
ctx,
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ use super::{
|
||||
self, period_jitter, period_warmup, JobGenerator, RunningJob, SchedulingResult,
|
||||
TenantBackgroundJobs,
|
||||
},
|
||||
CommandRequest, SecondaryTenantError, UploadCommand,
|
||||
CommandRequest, UploadCommand,
|
||||
};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info_span, instrument, Instrument};
|
||||
@@ -279,10 +279,7 @@ impl JobGenerator<UploadPending, WriteInProgress, WriteComplete, UploadCommand>
|
||||
}.instrument(info_span!(parent: None, "heatmap_upload", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()))))
|
||||
}
|
||||
|
||||
fn on_command(
|
||||
&mut self,
|
||||
command: UploadCommand,
|
||||
) -> Result<UploadPending, SecondaryTenantError> {
|
||||
fn on_command(&mut self, command: UploadCommand) -> anyhow::Result<UploadPending> {
|
||||
let tenant_shard_id = command.get_tenant_shard_id();
|
||||
|
||||
tracing::info!(
|
||||
@@ -290,7 +287,8 @@ impl JobGenerator<UploadPending, WriteInProgress, WriteComplete, UploadCommand>
|
||||
"Starting heatmap write on command");
|
||||
let tenant = self
|
||||
.tenant_manager
|
||||
.get_attached_tenant_shard(*tenant_shard_id)?;
|
||||
.get_attached_tenant_shard(*tenant_shard_id)
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
if !tenant.is_active() {
|
||||
return Err(GetTenantError::NotActive(*tenant_shard_id).into());
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use utils::{completion::Barrier, yielding_loop::yielding_loop};
|
||||
|
||||
use super::{CommandRequest, CommandResponse, SecondaryTenantError};
|
||||
use super::{CommandRequest, CommandResponse};
|
||||
|
||||
/// Scheduling interval is the time between calls to JobGenerator::schedule.
|
||||
/// When we schedule jobs, the job generator may provide a hint of its preferred
|
||||
@@ -112,7 +112,7 @@ where
|
||||
|
||||
/// Called when a command is received. A job will be spawned immediately if the return
|
||||
/// value is Some, ignoring concurrency limits and the pending queue.
|
||||
fn on_command(&mut self, cmd: CMD) -> Result<PJ, SecondaryTenantError>;
|
||||
fn on_command(&mut self, cmd: CMD) -> anyhow::Result<PJ>;
|
||||
}
|
||||
|
||||
/// [`JobGenerator`] returns this to provide pending jobs, and hints about scheduling
|
||||
|
||||
@@ -555,12 +555,13 @@ impl InMemoryLayer {
|
||||
timeline_id: TimelineId,
|
||||
tenant_shard_id: TenantShardId,
|
||||
start_lsn: Lsn,
|
||||
gate: &utils::sync::gate::Gate,
|
||||
gate_guard: utils::sync::gate::GateGuard,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<InMemoryLayer> {
|
||||
trace!("initializing new empty InMemoryLayer for writing on timeline {timeline_id} at {start_lsn}");
|
||||
|
||||
let file = EphemeralFile::create(conf, tenant_shard_id, timeline_id, gate, ctx).await?;
|
||||
let file =
|
||||
EphemeralFile::create(conf, tenant_shard_id, timeline_id, gate_guard, ctx).await?;
|
||||
let key = InMemoryLayerFileId(file.page_cache_file_id());
|
||||
|
||||
Ok(InMemoryLayer {
|
||||
|
||||
@@ -1149,7 +1149,6 @@ impl LayerInner {
|
||||
&self.desc.layer_name(),
|
||||
&self.metadata(),
|
||||
&self.path,
|
||||
&timeline.gate,
|
||||
&timeline.cancel,
|
||||
ctx,
|
||||
)
|
||||
|
||||
@@ -471,14 +471,14 @@ async fn ingest_housekeeping_loop(tenant: Arc<Tenant>, cancel: CancellationToken
|
||||
|
||||
// TODO: rename the background loop kind to something more generic, like, tenant housekeeping.
|
||||
// Or just spawn another background loop for this throttle, it's not like it's super costly.
|
||||
info_span!(parent: None, "pagestream_throttle", tenant_id=%tenant.tenant_shard_id, shard_id=%tenant.tenant_shard_id.shard_slug()).in_scope(|| {
|
||||
info_span!(parent: None, "timeline_get_throttle", tenant_id=%tenant.tenant_shard_id, shard_id=%tenant.tenant_shard_id.shard_slug()).in_scope(|| {
|
||||
let now = Instant::now();
|
||||
let prev = std::mem::replace(&mut last_throttle_flag_reset_at, now);
|
||||
let Stats { count_accounted_start, count_accounted_finish, count_throttled, sum_throttled_usecs} = tenant.pagestream_throttle.reset_stats();
|
||||
let Stats { count_accounted_start, count_accounted_finish, count_throttled, sum_throttled_usecs} = tenant.timeline_get_throttle.reset_stats();
|
||||
if count_throttled == 0 {
|
||||
return;
|
||||
}
|
||||
let allowed_rps = tenant.pagestream_throttle.steady_rps();
|
||||
let allowed_rps = tenant.timeline_get_throttle.steady_rps();
|
||||
let delta = now - prev;
|
||||
info!(
|
||||
n_seconds=%format_args!("{:.3}", delta.as_secs_f64()),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::{
|
||||
str::FromStr,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
@@ -7,8 +8,12 @@ use std::{
|
||||
};
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use enumset::EnumSet;
|
||||
use tracing::error;
|
||||
use utils::leaky_bucket::{LeakyBucketConfig, RateLimiter};
|
||||
|
||||
use crate::{context::RequestContext, task_mgr::TaskKind};
|
||||
|
||||
/// Throttle for `async` functions.
|
||||
///
|
||||
/// Runtime reconfigurable.
|
||||
@@ -30,7 +35,7 @@ pub struct Throttle<M: Metric> {
|
||||
}
|
||||
|
||||
pub struct Inner {
|
||||
enabled: bool,
|
||||
task_kinds: EnumSet<TaskKind>,
|
||||
rate_limiter: Arc<RateLimiter>,
|
||||
}
|
||||
|
||||
@@ -74,12 +79,26 @@ where
|
||||
}
|
||||
fn new_inner(config: Config) -> Inner {
|
||||
let Config {
|
||||
enabled,
|
||||
task_kinds,
|
||||
initial,
|
||||
refill_interval,
|
||||
refill_amount,
|
||||
max,
|
||||
} = config;
|
||||
let task_kinds: EnumSet<TaskKind> = task_kinds
|
||||
.iter()
|
||||
.filter_map(|s| match TaskKind::from_str(s) {
|
||||
Ok(v) => Some(v),
|
||||
Err(e) => {
|
||||
// TODO: avoid this failure mode
|
||||
error!(
|
||||
"cannot parse task kind, ignoring for rate limiting {}",
|
||||
utils::error::report_compact_sources(&e)
|
||||
);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// steady rate, we expect `refill_amount` requests per `refill_interval`.
|
||||
// dividing gives us the rps.
|
||||
@@ -93,7 +112,7 @@ where
|
||||
let rate_limiter = RateLimiter::with_initial_tokens(config, f64::from(initial_tokens));
|
||||
|
||||
Inner {
|
||||
enabled: enabled.is_enabled(),
|
||||
task_kinds,
|
||||
rate_limiter: Arc::new(rate_limiter),
|
||||
}
|
||||
}
|
||||
@@ -122,13 +141,11 @@ where
|
||||
self.inner.load().rate_limiter.steady_rps()
|
||||
}
|
||||
|
||||
pub async fn throttle(&self, key_count: usize) -> Option<Duration> {
|
||||
pub async fn throttle(&self, ctx: &RequestContext, key_count: usize) -> Option<Duration> {
|
||||
let inner = self.inner.load_full(); // clones the `Inner` Arc
|
||||
|
||||
if !inner.enabled {
|
||||
if !inner.task_kinds.contains(ctx.task_kind()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
};
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
self.metric.accounting_start();
|
||||
|
||||
@@ -208,8 +208,8 @@ fn drop_wlock<T>(rlock: tokio::sync::RwLockWriteGuard<'_, T>) {
|
||||
/// The outward-facing resources required to build a Timeline
|
||||
pub struct TimelineResources {
|
||||
pub remote_client: RemoteTimelineClient,
|
||||
pub pagestream_throttle:
|
||||
Arc<crate::tenant::throttle::Throttle<crate::metrics::tenant_throttling::Pagestream>>,
|
||||
pub timeline_get_throttle:
|
||||
Arc<crate::tenant::throttle::Throttle<crate::metrics::tenant_throttling::TimelineGet>>,
|
||||
pub l0_flush_global_state: l0_flush::L0FlushGlobalState,
|
||||
}
|
||||
|
||||
@@ -411,9 +411,9 @@ pub struct Timeline {
|
||||
/// Timeline deletion will acquire both compaction and gc locks in whatever order.
|
||||
gc_lock: tokio::sync::Mutex<()>,
|
||||
|
||||
/// Cloned from [`super::Tenant::pagestream_throttle`] on construction.
|
||||
pub(crate) pagestream_throttle:
|
||||
Arc<crate::tenant::throttle::Throttle<crate::metrics::tenant_throttling::Pagestream>>,
|
||||
/// Cloned from [`super::Tenant::timeline_get_throttle`] on construction.
|
||||
timeline_get_throttle:
|
||||
Arc<crate::tenant::throttle::Throttle<crate::metrics::tenant_throttling::TimelineGet>>,
|
||||
|
||||
/// Size estimator for aux file v2
|
||||
pub(crate) aux_file_size_estimator: AuxFileSizeEstimator,
|
||||
@@ -949,7 +949,7 @@ impl Timeline {
|
||||
/// If a remote layer file is needed, it is downloaded as part of this
|
||||
/// call.
|
||||
///
|
||||
/// This method enforces [`Self::pagestream_throttle`] internally.
|
||||
/// This method enforces [`Self::timeline_get_throttle`] internally.
|
||||
///
|
||||
/// NOTE: It is considered an error to 'get' a key that doesn't exist. The
|
||||
/// abstraction above this needs to store suitable metadata to track what
|
||||
@@ -977,6 +977,8 @@ impl Timeline {
|
||||
// page_service.
|
||||
debug_assert!(!self.shard_identity.is_key_disposable(&key));
|
||||
|
||||
self.timeline_get_throttle.throttle(ctx, 1).await;
|
||||
|
||||
let keyspace = KeySpace {
|
||||
ranges: vec![key..key.next()],
|
||||
};
|
||||
@@ -1056,6 +1058,14 @@ impl Timeline {
|
||||
.for_task_kind(ctx.task_kind())
|
||||
.map(|metric| (metric, Instant::now()));
|
||||
|
||||
// start counting after throttle so that throttle time
|
||||
// is always less than observation time and we don't
|
||||
// underflow when computing `ex_throttled` below.
|
||||
let throttled = self
|
||||
.timeline_get_throttle
|
||||
.throttle(ctx, key_count as usize)
|
||||
.await;
|
||||
|
||||
let res = self
|
||||
.get_vectored_impl(
|
||||
keyspace.clone(),
|
||||
@@ -1067,7 +1077,23 @@ impl Timeline {
|
||||
|
||||
if let Some((metric, start)) = start {
|
||||
let elapsed = start.elapsed();
|
||||
metric.observe(elapsed.as_secs_f64());
|
||||
let ex_throttled = if let Some(throttled) = throttled {
|
||||
elapsed.checked_sub(throttled)
|
||||
} else {
|
||||
Some(elapsed)
|
||||
};
|
||||
|
||||
if let Some(ex_throttled) = ex_throttled {
|
||||
metric.observe(ex_throttled.as_secs_f64());
|
||||
} else {
|
||||
use utils::rate_limit::RateLimit;
|
||||
static LOGGED: Lazy<Mutex<RateLimit>> =
|
||||
Lazy::new(|| Mutex::new(RateLimit::new(Duration::from_secs(10))));
|
||||
let mut rate_limit = LOGGED.lock().unwrap();
|
||||
rate_limit.call(|| {
|
||||
warn!("error deducting time spent throttled; this message is logged at a global rate limit");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
res
|
||||
@@ -1112,6 +1138,16 @@ impl Timeline {
|
||||
.for_task_kind(ctx.task_kind())
|
||||
.map(ScanLatencyOngoingRecording::start_recording);
|
||||
|
||||
// start counting after throttle so that throttle time
|
||||
// is always less than observation time and we don't
|
||||
// underflow when computing the `ex_throttled` value in
|
||||
// `recording.observe(throttled)` below.
|
||||
let throttled = self
|
||||
.timeline_get_throttle
|
||||
// assume scan = 1 quota for now until we find a better way to process this
|
||||
.throttle(ctx, 1)
|
||||
.await;
|
||||
|
||||
let vectored_res = self
|
||||
.get_vectored_impl(
|
||||
keyspace.clone(),
|
||||
@@ -1122,7 +1158,7 @@ impl Timeline {
|
||||
.await;
|
||||
|
||||
if let Some(recording) = start {
|
||||
recording.observe();
|
||||
recording.observe(throttled);
|
||||
}
|
||||
|
||||
vectored_res
|
||||
@@ -2338,7 +2374,7 @@ impl Timeline {
|
||||
|
||||
standby_horizon: AtomicLsn::new(0),
|
||||
|
||||
pagestream_throttle: resources.pagestream_throttle,
|
||||
timeline_get_throttle: resources.timeline_get_throttle,
|
||||
|
||||
aux_file_size_estimator: AuxFileSizeEstimator::new(aux_file_metrics),
|
||||
|
||||
@@ -3455,6 +3491,7 @@ impl Timeline {
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<Arc<InMemoryLayer>> {
|
||||
let mut guard = self.layers.write().await;
|
||||
let gate_guard = self.gate.enter().context("enter gate for inmem layer")?;
|
||||
|
||||
let last_record_lsn = self.get_last_record_lsn();
|
||||
ensure!(
|
||||
@@ -3471,7 +3508,7 @@ impl Timeline {
|
||||
self.conf,
|
||||
self.timeline_id,
|
||||
self.tenant_shard_id,
|
||||
&self.gate,
|
||||
gate_guard,
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -298,7 +298,7 @@ impl DeleteTimelineFlow {
|
||||
None, // Ancestor is not needed for deletion.
|
||||
TimelineResources {
|
||||
remote_client,
|
||||
pagestream_throttle: tenant.pagestream_throttle.clone(),
|
||||
timeline_get_throttle: tenant.timeline_get_throttle.clone(),
|
||||
l0_flush_global_state: tenant.l0_flush_global_state.clone(),
|
||||
},
|
||||
// Important. We dont pass ancestor above because it can be missing.
|
||||
|
||||
@@ -129,23 +129,22 @@ impl Flow {
|
||||
}
|
||||
|
||||
// Import SLRUs
|
||||
if self.timeline.tenant_shard_id.is_shard_zero() {
|
||||
// pg_xact (01:00 keyspace)
|
||||
self.import_slru(SlruKind::Clog, &self.storage.pgdata().join("pg_xact"))
|
||||
.await?;
|
||||
// pg_multixact/members (01:01 keyspace)
|
||||
self.import_slru(
|
||||
SlruKind::MultiXactMembers,
|
||||
&self.storage.pgdata().join("pg_multixact/members"),
|
||||
)
|
||||
|
||||
// pg_xact (01:00 keyspace)
|
||||
self.import_slru(SlruKind::Clog, &self.storage.pgdata().join("pg_xact"))
|
||||
.await?;
|
||||
// pg_multixact/offsets (01:02 keyspace)
|
||||
self.import_slru(
|
||||
SlruKind::MultiXactOffsets,
|
||||
&self.storage.pgdata().join("pg_multixact/offsets"),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
// pg_multixact/members (01:01 keyspace)
|
||||
self.import_slru(
|
||||
SlruKind::MultiXactMembers,
|
||||
&self.storage.pgdata().join("pg_multixact/members"),
|
||||
)
|
||||
.await?;
|
||||
// pg_multixact/offsets (01:02 keyspace)
|
||||
self.import_slru(
|
||||
SlruKind::MultiXactOffsets,
|
||||
&self.storage.pgdata().join("pg_multixact/offsets"),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Import pg_twophase.
|
||||
// TODO: as empty
|
||||
@@ -303,8 +302,6 @@ impl Flow {
|
||||
}
|
||||
|
||||
async fn import_slru(&mut self, kind: SlruKind, path: &RemotePath) -> anyhow::Result<()> {
|
||||
assert!(self.timeline.tenant_shard_id.is_shard_zero());
|
||||
|
||||
let segments = self.storage.listfilesindir(path).await?;
|
||||
let segments: Vec<(String, u32, usize)> = segments
|
||||
.into_iter()
|
||||
@@ -340,6 +337,7 @@ impl Flow {
|
||||
debug!(%p, segno=%segno, %size, %start_key, %end_key, "scheduling SLRU segment");
|
||||
self.tasks
|
||||
.push(AnyImportTask::SlruBlocks(ImportSlruBlocksTask::new(
|
||||
*self.timeline.get_shard_identity(),
|
||||
start_key..end_key,
|
||||
&p,
|
||||
self.storage.clone(),
|
||||
@@ -633,14 +631,21 @@ impl ImportTask for ImportRelBlocksTask {
|
||||
}
|
||||
|
||||
struct ImportSlruBlocksTask {
|
||||
shard_identity: ShardIdentity,
|
||||
key_range: Range<Key>,
|
||||
path: RemotePath,
|
||||
storage: RemoteStorageWrapper,
|
||||
}
|
||||
|
||||
impl ImportSlruBlocksTask {
|
||||
fn new(key_range: Range<Key>, path: &RemotePath, storage: RemoteStorageWrapper) -> Self {
|
||||
fn new(
|
||||
shard_identity: ShardIdentity,
|
||||
key_range: Range<Key>,
|
||||
path: &RemotePath,
|
||||
storage: RemoteStorageWrapper,
|
||||
) -> Self {
|
||||
ImportSlruBlocksTask {
|
||||
shard_identity,
|
||||
key_range,
|
||||
path: path.clone(),
|
||||
storage,
|
||||
@@ -668,13 +673,17 @@ impl ImportTask for ImportSlruBlocksTask {
|
||||
let mut file_offset = 0;
|
||||
while blknum < end_blk {
|
||||
let key = slru_block_to_key(kind, segno, blknum);
|
||||
assert!(
|
||||
!self.shard_identity.is_key_disposable(&key),
|
||||
"SLRU keys need to go into every shard"
|
||||
);
|
||||
let buf = &buf[file_offset..(file_offset + 8192)];
|
||||
file_offset += 8192;
|
||||
layer_writer
|
||||
.put_image(key, Bytes::copy_from_slice(buf), ctx)
|
||||
.await?;
|
||||
nimages += 1;
|
||||
blknum += 1;
|
||||
nimages += 1;
|
||||
}
|
||||
Ok(nimages)
|
||||
}
|
||||
|
||||
@@ -182,7 +182,7 @@ impl OpenLayerManager {
|
||||
conf: &'static PageServerConf,
|
||||
timeline_id: TimelineId,
|
||||
tenant_shard_id: TenantShardId,
|
||||
gate: &utils::sync::gate::Gate,
|
||||
gate_guard: utils::sync::gate::GateGuard,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<Arc<InMemoryLayer>> {
|
||||
ensure!(lsn.is_aligned());
|
||||
@@ -212,9 +212,15 @@ impl OpenLayerManager {
|
||||
lsn
|
||||
);
|
||||
|
||||
let new_layer =
|
||||
InMemoryLayer::create(conf, timeline_id, tenant_shard_id, start_lsn, gate, ctx)
|
||||
.await?;
|
||||
let new_layer = InMemoryLayer::create(
|
||||
conf,
|
||||
timeline_id,
|
||||
tenant_shard_id,
|
||||
start_lsn,
|
||||
gate_guard,
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
let layer = Arc::new(new_layer);
|
||||
|
||||
self.layer_map.open_layer = Some(layer.clone());
|
||||
|
||||
@@ -20,7 +20,7 @@ use camino::{Utf8Path, Utf8PathBuf};
|
||||
use once_cell::sync::OnceCell;
|
||||
use owned_buffers_io::aligned_buffer::buffer::AlignedBuffer;
|
||||
use owned_buffers_io::aligned_buffer::{AlignedBufferMut, AlignedSlice, ConstAlign};
|
||||
use owned_buffers_io::io_buf_aligned::{IoBufAligned, IoBufAlignedMut};
|
||||
use owned_buffers_io::io_buf_aligned::IoBufAlignedMut;
|
||||
use owned_buffers_io::io_buf_ext::FullSlice;
|
||||
use pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
@@ -63,6 +63,9 @@ pub(crate) mod owned_buffers_io {
|
||||
pub(crate) mod io_buf_ext;
|
||||
pub(crate) mod slice;
|
||||
pub(crate) mod write;
|
||||
pub(crate) mod util {
|
||||
pub(crate) mod size_tracking_writer;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -218,7 +221,7 @@ impl VirtualFile {
|
||||
self.inner.read_exact_at_page(page, offset, ctx).await
|
||||
}
|
||||
|
||||
pub async fn write_all_at<Buf: IoBufAligned + Send>(
|
||||
pub async fn write_all_at<Buf: IoBuf + Send>(
|
||||
&self,
|
||||
buf: FullSlice<Buf>,
|
||||
offset: u64,
|
||||
@@ -1322,14 +1325,14 @@ impl Drop for VirtualFileInner {
|
||||
}
|
||||
|
||||
impl OwnedAsyncWriter for VirtualFile {
|
||||
async fn write_all_at<Buf: IoBufAligned + Send>(
|
||||
&self,
|
||||
#[inline(always)]
|
||||
async fn write_all<Buf: IoBuf + Send>(
|
||||
&mut self,
|
||||
buf: FullSlice<Buf>,
|
||||
offset: u64,
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<FullSlice<Buf>> {
|
||||
let (buf, res) = VirtualFile::write_all_at(self, buf, offset, ctx).await;
|
||||
res.map(|_| buf)
|
||||
) -> std::io::Result<(usize, FullSlice<Buf>)> {
|
||||
let (buf, res) = VirtualFile::write_all(self, buf, ctx).await;
|
||||
res.map(move |v| (v, buf))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1448,7 +1451,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
async fn write_all_at<Buf: IoBufAligned + Send>(
|
||||
async fn write_all_at<Buf: IoBuf + Send>(
|
||||
&self,
|
||||
buf: FullSlice<Buf>,
|
||||
offset: u64,
|
||||
@@ -1591,7 +1594,6 @@ mod tests {
|
||||
&ctx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
file_a
|
||||
.write_all(b"foobar".to_vec().slice_len(), &ctx)
|
||||
.await?;
|
||||
@@ -1650,10 +1652,10 @@ mod tests {
|
||||
)
|
||||
.await?;
|
||||
file_b
|
||||
.write_all_at(IoBuffer::from(b"BAR").slice_len(), 3, &ctx)
|
||||
.write_all_at(b"BAR".to_vec().slice_len(), 3, &ctx)
|
||||
.await?;
|
||||
file_b
|
||||
.write_all_at(IoBuffer::from(b"FOO").slice_len(), 0, &ctx)
|
||||
.write_all_at(b"FOO".to_vec().slice_len(), 0, &ctx)
|
||||
.await?;
|
||||
|
||||
assert_eq!(file_b.read_string_at(2, 3, &ctx).await?, "OBA");
|
||||
|
||||
@@ -4,7 +4,7 @@ pub trait Alignment: std::marker::Unpin + 'static {
|
||||
}
|
||||
|
||||
/// Alignment at compile time.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[derive(Debug)]
|
||||
pub struct ConstAlign<const A: usize>;
|
||||
|
||||
impl<const A: usize> Alignment for ConstAlign<A> {
|
||||
@@ -14,7 +14,7 @@ impl<const A: usize> Alignment for ConstAlign<A> {
|
||||
}
|
||||
|
||||
/// Alignment at run time.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[derive(Debug)]
|
||||
pub struct RuntimeAlign {
|
||||
align: usize,
|
||||
}
|
||||
|
||||
@@ -3,10 +3,9 @@ use std::{
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use super::{alignment::Alignment, raw::RawAlignedBuffer, AlignedBufferMut, ConstAlign};
|
||||
use super::{alignment::Alignment, raw::RawAlignedBuffer};
|
||||
|
||||
/// An shared, immutable aligned buffer type.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AlignedBuffer<A: Alignment> {
|
||||
/// Shared raw buffer.
|
||||
raw: Arc<RawAlignedBuffer<A>>,
|
||||
@@ -87,13 +86,6 @@ impl<A: Alignment> AlignedBuffer<A> {
|
||||
range: begin..end,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the mutable aligned buffer, if the immutable aligned buffer
|
||||
/// has exactly one strong reference. Otherwise returns `None`.
|
||||
pub fn into_mut(self) -> Option<AlignedBufferMut<A>> {
|
||||
let raw = Arc::into_inner(self.raw)?;
|
||||
Some(AlignedBufferMut::from_raw(raw))
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Alignment> Deref for AlignedBuffer<A> {
|
||||
@@ -116,14 +108,6 @@ impl<A: Alignment> PartialEq<[u8]> for AlignedBuffer<A> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const A: usize, const N: usize> From<&[u8; N]> for AlignedBuffer<ConstAlign<A>> {
|
||||
fn from(value: &[u8; N]) -> Self {
|
||||
let mut buf = AlignedBufferMut::with_capacity(N);
|
||||
buf.extend_from_slice(value);
|
||||
buf.freeze()
|
||||
}
|
||||
}
|
||||
|
||||
/// SAFETY: the underlying buffer references a stable memory region.
|
||||
unsafe impl<A: Alignment> tokio_epoll_uring::IoBuf for AlignedBuffer<A> {
|
||||
fn stable_ptr(&self) -> *const u8 {
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
use std::{
|
||||
mem::MaybeUninit,
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use super::{
|
||||
alignment::{Alignment, ConstAlign},
|
||||
@@ -49,11 +46,6 @@ impl<const A: usize> AlignedBufferMut<ConstAlign<A>> {
|
||||
}
|
||||
|
||||
impl<A: Alignment> AlignedBufferMut<A> {
|
||||
/// Constructs a mutable aligned buffer from raw.
|
||||
pub(super) fn from_raw(raw: RawAlignedBuffer<A>) -> Self {
|
||||
AlignedBufferMut { raw }
|
||||
}
|
||||
|
||||
/// Returns the total number of bytes the buffer can hold.
|
||||
#[inline]
|
||||
pub fn capacity(&self) -> usize {
|
||||
@@ -136,39 +128,6 @@ impl<A: Alignment> AlignedBufferMut<A> {
|
||||
let len = self.len();
|
||||
AlignedBuffer::from_raw(self.raw, 0..len)
|
||||
}
|
||||
|
||||
/// Clones and appends all elements in a slice to the buffer. Reserves additional capacity as needed.
|
||||
#[inline]
|
||||
pub fn extend_from_slice(&mut self, extend: &[u8]) {
|
||||
let cnt = extend.len();
|
||||
self.reserve(cnt);
|
||||
|
||||
// SAFETY: we already reserved additional `cnt` bytes, safe to perform memcpy.
|
||||
unsafe {
|
||||
let dst = self.spare_capacity_mut();
|
||||
// Reserved above
|
||||
debug_assert!(dst.len() >= cnt);
|
||||
|
||||
core::ptr::copy_nonoverlapping(extend.as_ptr(), dst.as_mut_ptr().cast(), cnt);
|
||||
}
|
||||
// SAFETY: We do have at least `cnt` bytes remaining before advance.
|
||||
unsafe {
|
||||
bytes::BufMut::advance_mut(self, cnt);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the remaining spare capacity of the vector as a slice of `MaybeUninit<u8>`.
|
||||
#[inline]
|
||||
fn spare_capacity_mut(&mut self) -> &mut [MaybeUninit<u8>] {
|
||||
// SAFETY: we guarantees that the `Self::capacity()` bytes from
|
||||
// `Self::as_mut_ptr()` are allocated.
|
||||
unsafe {
|
||||
let ptr = self.as_mut_ptr().add(self.len());
|
||||
let len = self.capacity() - self.len();
|
||||
|
||||
core::slice::from_raw_parts_mut(ptr.cast(), len)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A: Alignment> Deref for AlignedBufferMut<A> {
|
||||
|
||||
@@ -1,15 +1,9 @@
|
||||
use tokio_epoll_uring::{IoBuf, IoBufMut};
|
||||
use tokio_epoll_uring::IoBufMut;
|
||||
|
||||
use crate::virtual_file::{IoBuffer, IoBufferMut, PageWriteGuardBuf};
|
||||
use crate::virtual_file::{IoBufferMut, PageWriteGuardBuf};
|
||||
|
||||
/// A marker trait for a mutable aligned buffer type.
|
||||
pub trait IoBufAlignedMut: IoBufMut {}
|
||||
|
||||
/// A marker trait for an aligned buffer type.
|
||||
pub trait IoBufAligned: IoBuf {}
|
||||
|
||||
impl IoBufAlignedMut for IoBufferMut {}
|
||||
|
||||
impl IoBufAligned for IoBuffer {}
|
||||
|
||||
impl IoBufAlignedMut for PageWriteGuardBuf {}
|
||||
|
||||
@@ -5,8 +5,6 @@ use bytes::{Bytes, BytesMut};
|
||||
use std::ops::{Deref, Range};
|
||||
use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
|
||||
|
||||
use super::write::CheapCloneForRead;
|
||||
|
||||
/// The true owned equivalent for Rust [`slice`]. Use this for the write path.
|
||||
///
|
||||
/// Unlike [`tokio_epoll_uring::Slice`], which we unfortunately inherited from `tokio-uring`,
|
||||
@@ -45,18 +43,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> CheapCloneForRead for FullSlice<B>
|
||||
where
|
||||
B: IoBuf + CheapCloneForRead,
|
||||
{
|
||||
fn cheap_clone(&self) -> Self {
|
||||
let bounds = self.slice.bounds();
|
||||
let clone = self.slice.get_ref().cheap_clone();
|
||||
let slice = clone.slice(bounds);
|
||||
Self { slice }
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait IoBufExt {
|
||||
/// Get a [`FullSlice`] for the entire buffer, i.e., `self[..]` or `self[0..self.len()]`.
|
||||
fn slice_len(self) -> FullSlice<Self>
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
use crate::{
|
||||
context::RequestContext,
|
||||
virtual_file::owned_buffers_io::{io_buf_ext::FullSlice, write::OwnedAsyncWriter},
|
||||
};
|
||||
use tokio_epoll_uring::IoBuf;
|
||||
|
||||
pub struct Writer<W> {
|
||||
dst: W,
|
||||
bytes_amount: u64,
|
||||
}
|
||||
|
||||
impl<W> Writer<W> {
|
||||
pub fn new(dst: W) -> Self {
|
||||
Self {
|
||||
dst,
|
||||
bytes_amount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bytes_written(&self) -> u64 {
|
||||
self.bytes_amount
|
||||
}
|
||||
|
||||
pub fn as_inner(&self) -> &W {
|
||||
&self.dst
|
||||
}
|
||||
|
||||
/// Returns the wrapped `VirtualFile` object as well as the number
|
||||
/// of bytes that were written to it through this object.
|
||||
#[cfg_attr(target_os = "macos", allow(dead_code))]
|
||||
pub fn into_inner(self) -> (u64, W) {
|
||||
(self.bytes_amount, self.dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> OwnedAsyncWriter for Writer<W>
|
||||
where
|
||||
W: OwnedAsyncWriter,
|
||||
{
|
||||
#[inline(always)]
|
||||
async fn write_all<Buf: IoBuf + Send>(
|
||||
&mut self,
|
||||
buf: FullSlice<Buf>,
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<(usize, FullSlice<Buf>)> {
|
||||
let (nwritten, buf) = self.dst.write_all(buf, ctx).await?;
|
||||
self.bytes_amount += u64::try_from(nwritten).unwrap();
|
||||
Ok((nwritten, buf))
|
||||
}
|
||||
}
|
||||
@@ -1,88 +1,55 @@
|
||||
mod flush;
|
||||
use std::sync::Arc;
|
||||
|
||||
use flush::FlushHandle;
|
||||
use bytes::BytesMut;
|
||||
use tokio_epoll_uring::IoBuf;
|
||||
|
||||
use crate::{
|
||||
context::RequestContext,
|
||||
virtual_file::{IoBuffer, IoBufferMut},
|
||||
};
|
||||
use crate::context::RequestContext;
|
||||
|
||||
use super::{
|
||||
io_buf_aligned::IoBufAligned,
|
||||
io_buf_ext::{FullSlice, IoBufExt},
|
||||
};
|
||||
|
||||
pub(crate) use flush::FlushControl;
|
||||
|
||||
pub(crate) trait CheapCloneForRead {
|
||||
/// Returns a cheap clone of the buffer.
|
||||
fn cheap_clone(&self) -> Self;
|
||||
}
|
||||
|
||||
impl CheapCloneForRead for IoBuffer {
|
||||
fn cheap_clone(&self) -> Self {
|
||||
// Cheap clone over an `Arc`.
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
use super::io_buf_ext::{FullSlice, IoBufExt};
|
||||
|
||||
/// A trait for doing owned-buffer write IO.
|
||||
/// Think [`tokio::io::AsyncWrite`] but with owned buffers.
|
||||
/// The owned buffers need to be aligned due to Direct IO requirements.
|
||||
pub trait OwnedAsyncWriter {
|
||||
fn write_all_at<Buf: IoBufAligned + Send>(
|
||||
&self,
|
||||
async fn write_all<Buf: IoBuf + Send>(
|
||||
&mut self,
|
||||
buf: FullSlice<Buf>,
|
||||
offset: u64,
|
||||
ctx: &RequestContext,
|
||||
) -> impl std::future::Future<Output = std::io::Result<FullSlice<Buf>>> + Send;
|
||||
) -> std::io::Result<(usize, FullSlice<Buf>)>;
|
||||
}
|
||||
|
||||
/// A wrapper aorund an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch
|
||||
/// small writes into larger writes of size [`Buffer::cap`].
|
||||
// TODO(yuchen): For large write, implementing buffer bypass for aligned parts of the write could be beneficial to throughput,
|
||||
// since we would avoid copying majority of the data into the internal buffer.
|
||||
pub struct BufferedWriter<B: Buffer, W> {
|
||||
writer: Arc<W>,
|
||||
///
|
||||
/// # Passthrough Of Large Writers
|
||||
///
|
||||
/// Calls to [`BufferedWriter::write_buffered`] that are larger than [`Buffer::cap`]
|
||||
/// cause the internal buffer to be flushed prematurely so that the large
|
||||
/// buffered write is passed through to the underlying [`OwnedAsyncWriter`].
|
||||
///
|
||||
/// This pass-through is generally beneficial for throughput, but if
|
||||
/// the storage backend of the [`OwnedAsyncWriter`] is a shared resource,
|
||||
/// unlimited large writes may cause latency or fairness issues.
|
||||
///
|
||||
/// In such cases, a different implementation that always buffers in memory
|
||||
/// may be preferable.
|
||||
pub struct BufferedWriter<B, W> {
|
||||
writer: W,
|
||||
/// invariant: always remains Some(buf) except
|
||||
/// - while IO is ongoing => goes back to Some() once the IO completed successfully
|
||||
/// - after an IO error => stays `None` forever
|
||||
///
|
||||
/// In these exceptional cases, it's `None`.
|
||||
mutable: Option<B>,
|
||||
/// A handle to the background flush task for writting data to disk.
|
||||
flush_handle: FlushHandle<B::IoBuf, W>,
|
||||
/// The number of bytes submitted to the background task.
|
||||
bytes_submitted: u64,
|
||||
buf: Option<B>,
|
||||
}
|
||||
|
||||
impl<B, Buf, W> BufferedWriter<B, W>
|
||||
where
|
||||
B: Buffer<IoBuf = Buf> + Send + 'static,
|
||||
Buf: IoBufAligned + Send + Sync + CheapCloneForRead,
|
||||
W: OwnedAsyncWriter + Send + Sync + 'static + std::fmt::Debug,
|
||||
B: Buffer<IoBuf = Buf> + Send,
|
||||
Buf: IoBuf + Send,
|
||||
W: OwnedAsyncWriter,
|
||||
{
|
||||
/// Creates a new buffered writer.
|
||||
///
|
||||
/// The `buf_new` function provides a way to initialize the owned buffers used by this writer.
|
||||
pub fn new(
|
||||
writer: Arc<W>,
|
||||
buf_new: impl Fn() -> B,
|
||||
gate_guard: utils::sync::gate::GateGuard,
|
||||
ctx: &RequestContext,
|
||||
) -> Self {
|
||||
pub fn new(writer: W, buf: B) -> Self {
|
||||
Self {
|
||||
writer: writer.clone(),
|
||||
mutable: Some(buf_new()),
|
||||
flush_handle: FlushHandle::spawn_new(
|
||||
writer,
|
||||
buf_new(),
|
||||
gate_guard,
|
||||
ctx.attached_child(),
|
||||
),
|
||||
bytes_submitted: 0,
|
||||
writer,
|
||||
buf: Some(buf),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,70 +57,87 @@ where
|
||||
&self.writer
|
||||
}
|
||||
|
||||
/// Returns the number of bytes submitted to the background flush task.
|
||||
pub fn bytes_submitted(&self) -> u64 {
|
||||
self.bytes_submitted
|
||||
}
|
||||
|
||||
/// Panics if used after any of the write paths returned an error
|
||||
pub fn inspect_mutable(&self) -> &B {
|
||||
self.mutable()
|
||||
}
|
||||
|
||||
/// Gets a reference to the maybe flushed read-only buffer.
|
||||
/// Returns `None` if the writer has not submitted any flush request.
|
||||
pub fn inspect_maybe_flushed(&self) -> Option<&FullSlice<Buf>> {
|
||||
self.flush_handle.maybe_flushed.as_ref()
|
||||
pub fn inspect_buffer(&self) -> &B {
|
||||
self.buf()
|
||||
}
|
||||
|
||||
#[cfg_attr(target_os = "macos", allow(dead_code))]
|
||||
pub async fn flush_and_into_inner(
|
||||
mut self,
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<(u64, Arc<W>)> {
|
||||
pub async fn flush_and_into_inner(mut self, ctx: &RequestContext) -> std::io::Result<W> {
|
||||
self.flush(ctx).await?;
|
||||
|
||||
let Self {
|
||||
mutable: buf,
|
||||
writer,
|
||||
mut flush_handle,
|
||||
bytes_submitted: bytes_amount,
|
||||
} = self;
|
||||
flush_handle.shutdown().await?;
|
||||
let Self { buf, writer } = self;
|
||||
assert!(buf.is_some());
|
||||
Ok((bytes_amount, writer))
|
||||
Ok(writer)
|
||||
}
|
||||
|
||||
/// Gets a reference to the mutable in-memory buffer.
|
||||
#[inline(always)]
|
||||
fn mutable(&self) -> &B {
|
||||
self.mutable
|
||||
fn buf(&self) -> &B {
|
||||
self.buf
|
||||
.as_ref()
|
||||
.expect("must not use after we returned an error")
|
||||
}
|
||||
|
||||
pub async fn write_buffered_borrowed(
|
||||
/// Guarantees that if Ok() is returned, all bytes in `chunk` have been accepted.
|
||||
#[cfg_attr(target_os = "macos", allow(dead_code))]
|
||||
pub async fn write_buffered<S: IoBuf + Send>(
|
||||
&mut self,
|
||||
chunk: &[u8],
|
||||
chunk: FullSlice<S>,
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<usize> {
|
||||
let (len, control) = self.write_buffered_borrowed_controlled(chunk, ctx).await?;
|
||||
if let Some(control) = control {
|
||||
control.release().await;
|
||||
) -> std::io::Result<(usize, FullSlice<S>)> {
|
||||
let chunk = chunk.into_raw_slice();
|
||||
|
||||
let chunk_len = chunk.len();
|
||||
// avoid memcpy for the middle of the chunk
|
||||
if chunk.len() >= self.buf().cap() {
|
||||
self.flush(ctx).await?;
|
||||
// do a big write, bypassing `buf`
|
||||
assert_eq!(
|
||||
self.buf
|
||||
.as_ref()
|
||||
.expect("must not use after an error")
|
||||
.pending(),
|
||||
0
|
||||
);
|
||||
let (nwritten, chunk) = self
|
||||
.writer
|
||||
.write_all(FullSlice::must_new(chunk), ctx)
|
||||
.await?;
|
||||
assert_eq!(nwritten, chunk_len);
|
||||
return Ok((nwritten, chunk));
|
||||
}
|
||||
Ok(len)
|
||||
// in-memory copy the < BUFFER_SIZED tail of the chunk
|
||||
assert!(chunk.len() < self.buf().cap());
|
||||
let mut slice = &chunk[..];
|
||||
while !slice.is_empty() {
|
||||
let buf = self.buf.as_mut().expect("must not use after an error");
|
||||
let need = buf.cap() - buf.pending();
|
||||
let have = slice.len();
|
||||
let n = std::cmp::min(need, have);
|
||||
buf.extend_from_slice(&slice[..n]);
|
||||
slice = &slice[n..];
|
||||
if buf.pending() >= buf.cap() {
|
||||
assert_eq!(buf.pending(), buf.cap());
|
||||
self.flush(ctx).await?;
|
||||
}
|
||||
}
|
||||
assert!(slice.is_empty(), "by now we should have drained the chunk");
|
||||
Ok((chunk_len, FullSlice::must_new(chunk)))
|
||||
}
|
||||
|
||||
/// In addition to bytes submitted in this write, also returns a handle that can control the flush behavior.
|
||||
pub(crate) async fn write_buffered_borrowed_controlled(
|
||||
/// Strictly less performant variant of [`Self::write_buffered`] that allows writing borrowed data.
|
||||
///
|
||||
/// It is less performant because we always have to copy the borrowed data into the internal buffer
|
||||
/// before we can do the IO. The [`Self::write_buffered`] can avoid this, which is more performant
|
||||
/// for large writes.
|
||||
pub async fn write_buffered_borrowed(
|
||||
&mut self,
|
||||
mut chunk: &[u8],
|
||||
ctx: &RequestContext,
|
||||
) -> std::io::Result<(usize, Option<FlushControl>)> {
|
||||
) -> std::io::Result<usize> {
|
||||
let chunk_len = chunk.len();
|
||||
let mut control: Option<FlushControl> = None;
|
||||
while !chunk.is_empty() {
|
||||
let buf = self.mutable.as_mut().expect("must not use after an error");
|
||||
let buf = self.buf.as_mut().expect("must not use after an error");
|
||||
let need = buf.cap() - buf.pending();
|
||||
let have = chunk.len();
|
||||
let n = std::cmp::min(need, have);
|
||||
@@ -161,27 +145,26 @@ where
|
||||
chunk = &chunk[n..];
|
||||
if buf.pending() >= buf.cap() {
|
||||
assert_eq!(buf.pending(), buf.cap());
|
||||
if let Some(control) = control.take() {
|
||||
control.release().await;
|
||||
}
|
||||
control = self.flush(ctx).await?;
|
||||
self.flush(ctx).await?;
|
||||
}
|
||||
}
|
||||
Ok((chunk_len, control))
|
||||
Ok(chunk_len)
|
||||
}
|
||||
|
||||
#[must_use = "caller must explcitly check the flush control"]
|
||||
async fn flush(&mut self, _ctx: &RequestContext) -> std::io::Result<Option<FlushControl>> {
|
||||
let buf = self.mutable.take().expect("must not use after an error");
|
||||
async fn flush(&mut self, ctx: &RequestContext) -> std::io::Result<()> {
|
||||
let buf = self.buf.take().expect("must not use after an error");
|
||||
let buf_len = buf.pending();
|
||||
if buf_len == 0 {
|
||||
self.mutable = Some(buf);
|
||||
return Ok(None);
|
||||
self.buf = Some(buf);
|
||||
return Ok(());
|
||||
}
|
||||
let (recycled, flush_control) = self.flush_handle.flush(buf, self.bytes_submitted).await?;
|
||||
self.bytes_submitted += u64::try_from(buf_len).unwrap();
|
||||
self.mutable = Some(recycled);
|
||||
Ok(Some(flush_control))
|
||||
let slice = buf.flush();
|
||||
let (nwritten, slice) = self.writer.write_all(slice, ctx).await?;
|
||||
assert_eq!(nwritten, buf_len);
|
||||
self.buf = Some(Buffer::reuse_after_flush(
|
||||
slice.into_raw_slice().into_inner(),
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,77 +192,64 @@ pub trait Buffer {
|
||||
fn reuse_after_flush(iobuf: Self::IoBuf) -> Self;
|
||||
}
|
||||
|
||||
impl Buffer for IoBufferMut {
|
||||
type IoBuf = IoBuffer;
|
||||
impl Buffer for BytesMut {
|
||||
type IoBuf = BytesMut;
|
||||
|
||||
#[inline(always)]
|
||||
fn cap(&self) -> usize {
|
||||
self.capacity()
|
||||
}
|
||||
|
||||
fn extend_from_slice(&mut self, other: &[u8]) {
|
||||
if self.len() + other.len() > self.cap() {
|
||||
panic!("Buffer capacity exceeded");
|
||||
}
|
||||
|
||||
IoBufferMut::extend_from_slice(self, other);
|
||||
BytesMut::extend_from_slice(self, other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn pending(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
|
||||
fn flush(self) -> FullSlice<Self::IoBuf> {
|
||||
self.freeze().slice_len()
|
||||
fn flush(self) -> FullSlice<BytesMut> {
|
||||
self.slice_len()
|
||||
}
|
||||
|
||||
/// Caller should make sure that `iobuf` only have one strong reference before invoking this method.
|
||||
fn reuse_after_flush(iobuf: Self::IoBuf) -> Self {
|
||||
let mut recycled = iobuf
|
||||
.into_mut()
|
||||
.expect("buffer should only have one strong reference");
|
||||
recycled.clear();
|
||||
recycled
|
||||
fn reuse_after_flush(mut iobuf: BytesMut) -> Self {
|
||||
iobuf.clear();
|
||||
iobuf
|
||||
}
|
||||
}
|
||||
|
||||
impl OwnedAsyncWriter for Vec<u8> {
|
||||
async fn write_all<Buf: IoBuf + Send>(
|
||||
&mut self,
|
||||
buf: FullSlice<Buf>,
|
||||
_: &RequestContext,
|
||||
) -> std::io::Result<(usize, FullSlice<Buf>)> {
|
||||
self.extend_from_slice(&buf[..]);
|
||||
Ok((buf.len(), buf))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Mutex;
|
||||
use bytes::BytesMut;
|
||||
|
||||
use super::*;
|
||||
use crate::context::{DownloadBehavior, RequestContext};
|
||||
use crate::task_mgr::TaskKind;
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
#[derive(Default)]
|
||||
struct RecorderWriter {
|
||||
/// record bytes and write offsets.
|
||||
writes: Mutex<Vec<(Vec<u8>, u64)>>,
|
||||
writes: Vec<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl RecorderWriter {
|
||||
/// Gets recorded bytes and write offsets.
|
||||
fn get_writes(&self) -> Vec<Vec<u8>> {
|
||||
self.writes
|
||||
.lock()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|(buf, _)| buf.clone())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl OwnedAsyncWriter for RecorderWriter {
|
||||
async fn write_all_at<Buf: IoBufAligned + Send>(
|
||||
&self,
|
||||
async fn write_all<Buf: IoBuf + Send>(
|
||||
&mut self,
|
||||
buf: FullSlice<Buf>,
|
||||
offset: u64,
|
||||
_: &RequestContext,
|
||||
) -> std::io::Result<FullSlice<Buf>> {
|
||||
self.writes
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push((Vec::from(&buf[..]), offset));
|
||||
Ok(buf)
|
||||
) -> std::io::Result<(usize, FullSlice<Buf>)> {
|
||||
self.writes.push(Vec::from(&buf[..]));
|
||||
Ok((buf.len(), buf))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,21 +257,71 @@ mod tests {
|
||||
RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error)
|
||||
}
|
||||
|
||||
macro_rules! write {
|
||||
($writer:ident, $data:literal) => {{
|
||||
$writer
|
||||
.write_buffered(::bytes::Bytes::from_static($data).slice_len(), &test_ctx())
|
||||
.await?;
|
||||
}};
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_all_borrowed_always_goes_through_buffer() -> anyhow::Result<()> {
|
||||
async fn test_buffered_writes_only() -> std::io::Result<()> {
|
||||
let recorder = RecorderWriter::default();
|
||||
let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
|
||||
write!(writer, b"a");
|
||||
write!(writer, b"b");
|
||||
write!(writer, b"c");
|
||||
write!(writer, b"d");
|
||||
write!(writer, b"e");
|
||||
let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
|
||||
assert_eq!(
|
||||
recorder.writes,
|
||||
vec![Vec::from(b"ab"), Vec::from(b"cd"), Vec::from(b"e")]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_passthrough_writes_only() -> std::io::Result<()> {
|
||||
let recorder = RecorderWriter::default();
|
||||
let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
|
||||
write!(writer, b"abc");
|
||||
write!(writer, b"de");
|
||||
write!(writer, b"");
|
||||
write!(writer, b"fghijk");
|
||||
let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
|
||||
assert_eq!(
|
||||
recorder.writes,
|
||||
vec![Vec::from(b"abc"), Vec::from(b"de"), Vec::from(b"fghijk")]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_passthrough_write_with_nonempty_buffer() -> std::io::Result<()> {
|
||||
let recorder = RecorderWriter::default();
|
||||
let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
|
||||
write!(writer, b"a");
|
||||
write!(writer, b"bc");
|
||||
write!(writer, b"d");
|
||||
write!(writer, b"e");
|
||||
let recorder = writer.flush_and_into_inner(&test_ctx()).await?;
|
||||
assert_eq!(
|
||||
recorder.writes,
|
||||
vec![Vec::from(b"a"), Vec::from(b"bc"), Vec::from(b"de")]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_all_borrowed_always_goes_through_buffer() -> std::io::Result<()> {
|
||||
let ctx = test_ctx();
|
||||
let ctx = &ctx;
|
||||
let recorder = Arc::new(RecorderWriter::default());
|
||||
let gate = utils::sync::gate::Gate::default();
|
||||
let mut writer = BufferedWriter::<_, RecorderWriter>::new(
|
||||
recorder,
|
||||
|| IoBufferMut::with_capacity(2),
|
||||
gate.enter()?,
|
||||
ctx,
|
||||
);
|
||||
let recorder = RecorderWriter::default();
|
||||
let mut writer = BufferedWriter::new(recorder, BytesMut::with_capacity(2));
|
||||
|
||||
writer.write_buffered_borrowed(b"abc", ctx).await?;
|
||||
writer.write_buffered_borrowed(b"", ctx).await?;
|
||||
writer.write_buffered_borrowed(b"d", ctx).await?;
|
||||
writer.write_buffered_borrowed(b"e", ctx).await?;
|
||||
writer.write_buffered_borrowed(b"fg", ctx).await?;
|
||||
@@ -309,9 +329,9 @@ mod tests {
|
||||
writer.write_buffered_borrowed(b"j", ctx).await?;
|
||||
writer.write_buffered_borrowed(b"klmno", ctx).await?;
|
||||
|
||||
let (_, recorder) = writer.flush_and_into_inner(ctx).await?;
|
||||
let recorder = writer.flush_and_into_inner(ctx).await?;
|
||||
assert_eq!(
|
||||
recorder.get_writes(),
|
||||
recorder.writes,
|
||||
{
|
||||
let expect: &[&[u8]] = &[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"];
|
||||
expect
|
||||
|
||||
@@ -1,314 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use utils::sync::duplex;
|
||||
|
||||
use crate::{
|
||||
context::RequestContext,
|
||||
virtual_file::owned_buffers_io::{io_buf_aligned::IoBufAligned, io_buf_ext::FullSlice},
|
||||
};
|
||||
|
||||
use super::{Buffer, CheapCloneForRead, OwnedAsyncWriter};
|
||||
|
||||
/// A handle to the flush task.
|
||||
pub struct FlushHandle<Buf, W> {
|
||||
inner: Option<FlushHandleInner<Buf, W>>,
|
||||
/// Immutable buffer for serving tail reads.
|
||||
/// `None` if no flush request has been submitted.
|
||||
pub(super) maybe_flushed: Option<FullSlice<Buf>>,
|
||||
}
|
||||
|
||||
pub struct FlushHandleInner<Buf, W> {
|
||||
/// A bi-directional channel that sends (buffer, offset) for writes,
|
||||
/// and receives recyled buffer.
|
||||
channel: duplex::mpsc::Duplex<FlushRequest<Buf>, FullSlice<Buf>>,
|
||||
/// Join handle for the background flush task.
|
||||
join_handle: tokio::task::JoinHandle<std::io::Result<Arc<W>>>,
|
||||
}
|
||||
|
||||
struct FlushRequest<Buf> {
|
||||
slice: FullSlice<Buf>,
|
||||
offset: u64,
|
||||
#[cfg(test)]
|
||||
ready_to_flush_rx: tokio::sync::oneshot::Receiver<()>,
|
||||
#[cfg(test)]
|
||||
done_flush_tx: tokio::sync::oneshot::Sender<()>,
|
||||
}
|
||||
|
||||
/// Constructs a request and a control object for a new flush operation.
|
||||
#[cfg(not(test))]
|
||||
fn new_flush_op<Buf>(slice: FullSlice<Buf>, offset: u64) -> (FlushRequest<Buf>, FlushControl) {
|
||||
let request = FlushRequest { slice, offset };
|
||||
let control = FlushControl::untracked();
|
||||
|
||||
(request, control)
|
||||
}
|
||||
|
||||
/// Constructs a request and a control object for a new flush operation.
|
||||
#[cfg(test)]
|
||||
fn new_flush_op<Buf>(slice: FullSlice<Buf>, offset: u64) -> (FlushRequest<Buf>, FlushControl) {
|
||||
let (ready_to_flush_tx, ready_to_flush_rx) = tokio::sync::oneshot::channel();
|
||||
let (done_flush_tx, done_flush_rx) = tokio::sync::oneshot::channel();
|
||||
let control = FlushControl::not_started(ready_to_flush_tx, done_flush_rx);
|
||||
|
||||
let request = FlushRequest {
|
||||
slice,
|
||||
offset,
|
||||
ready_to_flush_rx,
|
||||
done_flush_tx,
|
||||
};
|
||||
(request, control)
|
||||
}
|
||||
|
||||
/// A handle to a `FlushRequest` that allows unit tests precise control over flush behavior.
|
||||
#[cfg(test)]
|
||||
pub(crate) struct FlushControl {
|
||||
not_started: FlushNotStarted,
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub(crate) struct FlushControl;
|
||||
|
||||
impl FlushControl {
|
||||
#[cfg(test)]
|
||||
fn not_started(
|
||||
ready_to_flush_tx: tokio::sync::oneshot::Sender<()>,
|
||||
done_flush_rx: tokio::sync::oneshot::Receiver<()>,
|
||||
) -> Self {
|
||||
FlushControl {
|
||||
not_started: FlushNotStarted {
|
||||
ready_to_flush_tx,
|
||||
done_flush_rx,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
fn untracked() -> Self {
|
||||
FlushControl
|
||||
}
|
||||
|
||||
/// In tests, turn flush control into a not started state.
|
||||
#[cfg(test)]
|
||||
pub(crate) fn into_not_started(self) -> FlushNotStarted {
|
||||
self.not_started
|
||||
}
|
||||
|
||||
/// Release control to the submitted buffer.
|
||||
///
|
||||
/// In `cfg(test)` environment, the buffer is guranteed to be flushed to disk after [`FlushControl::release`] is finishes execution.
|
||||
pub async fn release(self) {
|
||||
#[cfg(test)]
|
||||
{
|
||||
self.not_started
|
||||
.ready_to_flush()
|
||||
.wait_until_flush_is_done()
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Buf, W> FlushHandle<Buf, W>
|
||||
where
|
||||
Buf: IoBufAligned + Send + Sync + CheapCloneForRead,
|
||||
W: OwnedAsyncWriter + Send + Sync + 'static + std::fmt::Debug,
|
||||
{
|
||||
/// Spawns a new background flush task and obtains a handle.
|
||||
///
|
||||
/// Note: The background task so we do not need to explicitly maintain a queue of buffers.
|
||||
pub fn spawn_new<B>(
|
||||
file: Arc<W>,
|
||||
buf: B,
|
||||
gate_guard: utils::sync::gate::GateGuard,
|
||||
ctx: RequestContext,
|
||||
) -> Self
|
||||
where
|
||||
B: Buffer<IoBuf = Buf> + Send + 'static,
|
||||
{
|
||||
// It is fine to buffer up to only 1 message. We only 1 message in-flight at a time.
|
||||
let (front, back) = duplex::mpsc::channel(1);
|
||||
|
||||
let join_handle = tokio::spawn(async move {
|
||||
FlushBackgroundTask::new(back, file, gate_guard, ctx)
|
||||
.run(buf.flush())
|
||||
.await
|
||||
});
|
||||
|
||||
FlushHandle {
|
||||
inner: Some(FlushHandleInner {
|
||||
channel: front,
|
||||
join_handle,
|
||||
}),
|
||||
maybe_flushed: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Submits a buffer to be flushed in the background task.
|
||||
/// Returns a buffer that completed flushing for re-use, length reset to 0, capacity unchanged.
|
||||
/// If `save_buf_for_read` is true, then we save the buffer in `Self::maybe_flushed`, otherwise
|
||||
/// clear `maybe_flushed`.
|
||||
pub async fn flush<B>(&mut self, buf: B, offset: u64) -> std::io::Result<(B, FlushControl)>
|
||||
where
|
||||
B: Buffer<IoBuf = Buf> + Send + 'static,
|
||||
{
|
||||
let slice = buf.flush();
|
||||
|
||||
// Saves a buffer for read while flushing. This also removes reference to the old buffer.
|
||||
self.maybe_flushed = Some(slice.cheap_clone());
|
||||
|
||||
let (request, flush_control) = new_flush_op(slice, offset);
|
||||
|
||||
// Submits the buffer to the background task.
|
||||
let submit = self.inner_mut().channel.send(request).await;
|
||||
if submit.is_err() {
|
||||
return self.handle_error().await;
|
||||
}
|
||||
|
||||
// Wait for an available buffer from the background flush task.
|
||||
// This is the BACKPRESSURE mechanism: if the flush task can't keep up,
|
||||
// then the write path will eventually wait for it here.
|
||||
let Some(recycled) = self.inner_mut().channel.recv().await else {
|
||||
return self.handle_error().await;
|
||||
};
|
||||
|
||||
// The only other place that could hold a reference to the recycled buffer
|
||||
// is in `Self::maybe_flushed`, but we have already replace it with the new buffer.
|
||||
let recycled = Buffer::reuse_after_flush(recycled.into_raw_slice().into_inner());
|
||||
Ok((recycled, flush_control))
|
||||
}
|
||||
|
||||
async fn handle_error<T>(&mut self) -> std::io::Result<T> {
|
||||
Err(self
|
||||
.shutdown()
|
||||
.await
|
||||
.expect_err("flush task only disconnects duplex if it exits with an error"))
|
||||
}
|
||||
|
||||
/// Cleans up the channel, join the flush task.
|
||||
pub async fn shutdown(&mut self) -> std::io::Result<Arc<W>> {
|
||||
let handle = self
|
||||
.inner
|
||||
.take()
|
||||
.expect("must not use after we returned an error");
|
||||
drop(handle.channel.tx);
|
||||
handle.join_handle.await.unwrap()
|
||||
}
|
||||
|
||||
/// Gets a mutable reference to the inner handle. Panics if [`Self::inner`] is `None`.
|
||||
/// This only happens if the handle is used after an error.
|
||||
fn inner_mut(&mut self) -> &mut FlushHandleInner<Buf, W> {
|
||||
self.inner
|
||||
.as_mut()
|
||||
.expect("must not use after we returned an error")
|
||||
}
|
||||
}
|
||||
|
||||
/// A background task for flushing data to disk.
|
||||
pub struct FlushBackgroundTask<Buf, W> {
|
||||
/// A bi-directional channel that receives (buffer, offset) for writes,
|
||||
/// and send back recycled buffer.
|
||||
channel: duplex::mpsc::Duplex<FullSlice<Buf>, FlushRequest<Buf>>,
|
||||
/// A writter for persisting data to disk.
|
||||
writer: Arc<W>,
|
||||
ctx: RequestContext,
|
||||
/// Prevent timeline from shuting down until the flush background task finishes flushing all remaining buffers to disk.
|
||||
_gate_guard: utils::sync::gate::GateGuard,
|
||||
}
|
||||
|
||||
impl<Buf, W> FlushBackgroundTask<Buf, W>
|
||||
where
|
||||
Buf: IoBufAligned + Send + Sync,
|
||||
W: OwnedAsyncWriter + Sync + 'static,
|
||||
{
|
||||
/// Creates a new background flush task.
|
||||
fn new(
|
||||
channel: duplex::mpsc::Duplex<FullSlice<Buf>, FlushRequest<Buf>>,
|
||||
file: Arc<W>,
|
||||
gate_guard: utils::sync::gate::GateGuard,
|
||||
ctx: RequestContext,
|
||||
) -> Self {
|
||||
FlushBackgroundTask {
|
||||
channel,
|
||||
writer: file,
|
||||
_gate_guard: gate_guard,
|
||||
ctx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs the background flush task.
|
||||
/// The passed in slice is immediately sent back to the flush handle through the duplex channel.
|
||||
async fn run(mut self, slice: FullSlice<Buf>) -> std::io::Result<Arc<W>> {
|
||||
// Sends the extra buffer back to the handle.
|
||||
self.channel.send(slice).await.map_err(|_| {
|
||||
std::io::Error::new(std::io::ErrorKind::BrokenPipe, "flush handle closed early")
|
||||
})?;
|
||||
|
||||
// Exit condition: channel is closed and there is no remaining buffer to be flushed
|
||||
while let Some(request) = self.channel.recv().await {
|
||||
#[cfg(test)]
|
||||
{
|
||||
// In test, wait for control to signal that we are ready to flush.
|
||||
if request.ready_to_flush_rx.await.is_err() {
|
||||
tracing::debug!("control dropped");
|
||||
}
|
||||
}
|
||||
|
||||
// Write slice to disk at `offset`.
|
||||
let slice = self
|
||||
.writer
|
||||
.write_all_at(request.slice, request.offset, &self.ctx)
|
||||
.await?;
|
||||
|
||||
#[cfg(test)]
|
||||
{
|
||||
// In test, tell control we are done flushing buffer.
|
||||
if request.done_flush_tx.send(()).is_err() {
|
||||
tracing::debug!("control dropped");
|
||||
}
|
||||
}
|
||||
|
||||
// Sends the buffer back to the handle for reuse. The handle is in charged of cleaning the buffer.
|
||||
if self.channel.send(slice).await.is_err() {
|
||||
// Although channel is closed. Still need to finish flushing the remaining buffers.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self.writer)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) struct FlushNotStarted {
|
||||
ready_to_flush_tx: tokio::sync::oneshot::Sender<()>,
|
||||
done_flush_rx: tokio::sync::oneshot::Receiver<()>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) struct FlushInProgress {
|
||||
done_flush_rx: tokio::sync::oneshot::Receiver<()>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) struct FlushDone;
|
||||
|
||||
#[cfg(test)]
|
||||
impl FlushNotStarted {
|
||||
/// Signals the background task the buffer is ready to flush to disk.
|
||||
pub fn ready_to_flush(self) -> FlushInProgress {
|
||||
self.ready_to_flush_tx
|
||||
.send(())
|
||||
.map(|_| FlushInProgress {
|
||||
done_flush_rx: self.done_flush_rx,
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl FlushInProgress {
|
||||
/// Waits until background flush is done.
|
||||
pub async fn wait_until_flush_is_done(self) -> FlushDone {
|
||||
self.done_flush_rx.await.unwrap();
|
||||
FlushDone
|
||||
}
|
||||
}
|
||||
@@ -1392,10 +1392,6 @@ impl WalIngest {
|
||||
img: Bytes,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<()> {
|
||||
if !self.shard.is_shard_zero() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.handle_slru_extend(modification, kind, segno, blknum, ctx)
|
||||
.await?;
|
||||
modification.put_slru_page_image(kind, segno, blknum, img)?;
|
||||
|
||||
@@ -12,15 +12,21 @@
|
||||
|
||||
#include <curl/curl.h>
|
||||
|
||||
#include "access/xact.h"
|
||||
#include "utils/guc.h"
|
||||
#include "tcop/utility.h"
|
||||
|
||||
#include "extension_server.h"
|
||||
#include "neon_utils.h"
|
||||
|
||||
static int extension_server_port = 0;
|
||||
static int extension_server_port = 0;
|
||||
|
||||
static download_extension_file_hook_type prev_download_extension_file_hook = NULL;
|
||||
|
||||
static ProcessUtility_hook_type PreviousProcessUtilityHook = NULL;
|
||||
|
||||
static bool extension_ddl_occured = false;
|
||||
|
||||
/*
|
||||
* to download all SQL (and data) files for an extension:
|
||||
* curl -X POST http://localhost:8080/extension_server/postgis
|
||||
@@ -74,9 +80,154 @@ neon_download_extension_file_http(const char *filename, bool is_library)
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
// Handle extension DDL: we need this for monitoring of installed extensions.
|
||||
// General solution is hard, because extensions can be installed in many ways,
|
||||
// i.e. sometimes as a cascade operations.
|
||||
//
|
||||
// Also, we don't have enough information in the statement itself,
|
||||
// i.e. extension version is not always present and retrieved from the control file
|
||||
// at a later stage.
|
||||
//
|
||||
// So, just remember the fact of the extension DDL and send it to compute_ctl
|
||||
// on commit.
|
||||
static void
|
||||
NeonExtensionProcessUtility(
|
||||
PlannedStmt *pstmt,
|
||||
const char *queryString,
|
||||
bool readOnlyTree,
|
||||
ProcessUtilityContext context,
|
||||
ParamListInfo params,
|
||||
QueryEnvironment *queryEnv,
|
||||
DestReceiver *dest,
|
||||
QueryCompletion *qc)
|
||||
{
|
||||
Node *parseTree = pstmt->utilityStmt;
|
||||
|
||||
switch (nodeTag(parseTree))
|
||||
{
|
||||
case T_CreateExtensionStmt:
|
||||
case T_AlterExtensionStmt:
|
||||
extension_ddl_occured = true;
|
||||
elog(LOG, "Extension DDL occured");
|
||||
break;
|
||||
case T_DropStmt:
|
||||
{
|
||||
switch (((DropStmt *) parseTree)->removeType)
|
||||
{
|
||||
case OBJECT_EXTENSION:
|
||||
extension_ddl_occured = true;
|
||||
elog(LOG, "Extension DDL occured");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
if (PreviousProcessUtilityHook)
|
||||
{
|
||||
PreviousProcessUtilityHook(
|
||||
pstmt,
|
||||
queryString,
|
||||
readOnlyTree,
|
||||
context,
|
||||
params,
|
||||
queryEnv,
|
||||
dest,
|
||||
qc);
|
||||
}
|
||||
else
|
||||
{
|
||||
standard_ProcessUtility(
|
||||
pstmt,
|
||||
queryString,
|
||||
readOnlyTree,
|
||||
context,
|
||||
params,
|
||||
queryEnv,
|
||||
dest,
|
||||
qc);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
static bool
|
||||
neon_send_extension_ddl_event_http()
|
||||
{
|
||||
static CURL *handle = NULL;
|
||||
|
||||
CURLcode res;
|
||||
char *compute_ctl_url;
|
||||
bool ret = false;
|
||||
|
||||
if (handle == NULL)
|
||||
{
|
||||
handle = alloc_curl_handle();
|
||||
}
|
||||
|
||||
compute_ctl_url = psprintf("http://localhost:%d/installed_extensions",
|
||||
extension_server_port);
|
||||
|
||||
elog(LOG, "Sending ddl event to compute_ctl: %s", compute_ctl_url);
|
||||
|
||||
curl_easy_setopt(handle, CURLOPT_URL, compute_ctl_url);
|
||||
|
||||
// Use HEAD request without payload, because this is just a notification.
|
||||
//
|
||||
// This is probably not the best API design, but I didn't want to introduce
|
||||
// new endpoint for this. Suggestions are welcome.
|
||||
curl_easy_setopt(handle, CURLOPT_NOBODY, 1L);
|
||||
|
||||
/* Perform the request, res will get the return code */
|
||||
res = curl_easy_perform(handle);
|
||||
|
||||
/* Check for errors */
|
||||
if (res != CURLE_OK)
|
||||
{
|
||||
/*
|
||||
* Don't error here because this is just a monitoring feature.
|
||||
*/
|
||||
elog(WARNING, "neon_send_extension_ddl_event_http failed: %s\n", curl_easy_strerror(res));
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
static void
|
||||
NeonExtensionXactCallback(XactEvent event, void *arg)
|
||||
{
|
||||
elog(LOG, "NeonExtensionXactCallback: %d", event);
|
||||
/* the handler on the compute_ctl side must be non-blocking
|
||||
* otherwise, the compute_ctl won't see the data that is not yet committed.
|
||||
* There is still a chance of the race, because the data becomes visible only after XACT_EVENT_COMMIT
|
||||
* callback is called. We assume that this is not critical for the monitoring feature and expect that
|
||||
* compute_ctl will eventually see the data.
|
||||
*/
|
||||
if (event == XACT_EVENT_COMMIT || event == XACT_EVENT_PARALLEL_COMMIT)
|
||||
{
|
||||
if (extension_ddl_occured)
|
||||
{
|
||||
elog(LOG, "Sending extension DDL event to compute_ctl");
|
||||
neon_send_extension_ddl_event_http();
|
||||
extension_ddl_occured = false;
|
||||
}
|
||||
elog(LOG, "no extension DDL");
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
pg_init_extension_server()
|
||||
{
|
||||
|
||||
PreviousProcessUtilityHook = ProcessUtility_hook;
|
||||
ProcessUtility_hook = NeonExtensionProcessUtility;
|
||||
RegisterXactCallback(NeonExtensionXactCallback, NULL);
|
||||
|
||||
/* Port to connect to compute_ctl on localhost */
|
||||
/* to request extension files. */
|
||||
DefineCustomIntVariable("neon.extension_server_port",
|
||||
|
||||
@@ -6,7 +6,7 @@ license.workspace = true
|
||||
|
||||
[features]
|
||||
default = []
|
||||
testing = ["dep:tokio-postgres"]
|
||||
testing = []
|
||||
|
||||
[dependencies]
|
||||
ahash.workspace = true
|
||||
@@ -55,7 +55,6 @@ parquet.workspace = true
|
||||
parquet_derive.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
postgres-client = { package = "tokio-postgres2", path = "../libs/proxy/tokio-postgres2" }
|
||||
postgres-protocol = { package = "postgres-protocol2", path = "../libs/proxy/postgres-protocol2" }
|
||||
pq_proto.workspace = true
|
||||
prometheus.workspace = true
|
||||
@@ -82,7 +81,7 @@ subtle.workspace = true
|
||||
thiserror.workspace = true
|
||||
tikv-jemallocator.workspace = true
|
||||
tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] }
|
||||
tokio-postgres = { workspace = true, optional = true }
|
||||
tokio-postgres = { package = "tokio-postgres2", path = "../libs/proxy/tokio-postgres2" }
|
||||
tokio-rustls.workspace = true
|
||||
tokio-util.workspace = true
|
||||
tokio = { workspace = true, features = ["signal"] }
|
||||
@@ -120,4 +119,3 @@ rcgen.workspace = true
|
||||
rstest.workspace = true
|
||||
walkdir.workspace = true
|
||||
rand_distr = "0.4"
|
||||
tokio-postgres.workspace = true
|
||||
|
||||
@@ -66,7 +66,7 @@ pub(super) async fn authenticate(
|
||||
|
||||
Ok(ComputeCredentials {
|
||||
info: creds,
|
||||
keys: ComputeCredentialKeys::AuthKeys(postgres_client::config::AuthKeys::ScramSha256(
|
||||
keys: ComputeCredentialKeys::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256(
|
||||
scram_keys,
|
||||
)),
|
||||
})
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use async_trait::async_trait;
|
||||
use postgres_client::config::SslMode;
|
||||
use pq_proto::BeMessage as Be;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{info, info_span};
|
||||
|
||||
use super::ComputeCredentialKeys;
|
||||
@@ -49,19 +49,13 @@ impl ReportableError for ConsoleRedirectError {
|
||||
}
|
||||
}
|
||||
|
||||
fn hello_message(
|
||||
redirect_uri: &reqwest::Url,
|
||||
session_id: &str,
|
||||
duration: std::time::Duration,
|
||||
) -> String {
|
||||
let formatted_duration = humantime::format_duration(duration).to_string();
|
||||
fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
|
||||
format!(
|
||||
concat![
|
||||
"Welcome to Neon!\n",
|
||||
"Authenticate by visiting (will expire in {duration}):\n",
|
||||
"Authenticate by visiting:\n",
|
||||
" {redirect_uri}{session_id}\n\n",
|
||||
],
|
||||
duration = formatted_duration,
|
||||
redirect_uri = redirect_uri,
|
||||
session_id = session_id,
|
||||
)
|
||||
@@ -124,11 +118,7 @@ async fn authenticate(
|
||||
};
|
||||
|
||||
let span = info_span!("console_redirect", psql_session_id = &psql_session_id);
|
||||
let greeting = hello_message(
|
||||
link_uri,
|
||||
&psql_session_id,
|
||||
auth_config.console_redirect_confirmation_timeout,
|
||||
);
|
||||
let greeting = hello_message(link_uri, &psql_session_id);
|
||||
|
||||
// Give user a URL to spawn a new database.
|
||||
info!(parent: &span, "sending the auth URL to the user");
|
||||
@@ -161,8 +151,12 @@ async fn authenticate(
|
||||
|
||||
// This config should be self-contained, because we won't
|
||||
// take username or dbname from client's startup message.
|
||||
let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port);
|
||||
config.dbname(&db_info.dbname).user(&db_info.user);
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(&db_info.host)
|
||||
.port(db_info.port)
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
ctx.set_dbname(db_info.dbname.into());
|
||||
ctx.set_user(db_info.user.into());
|
||||
|
||||
@@ -29,7 +29,12 @@ impl LocalBackend {
|
||||
api: http::Endpoint::new(compute_ctl, http::new_client()),
|
||||
},
|
||||
node_info: NodeInfo {
|
||||
config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()),
|
||||
config: {
|
||||
let mut cfg = ConnCfg::new();
|
||||
cfg.host(&postgres_addr.ip().to_string());
|
||||
cfg.port(postgres_addr.port());
|
||||
cfg
|
||||
},
|
||||
// TODO(conrad): make this better reflect compute info rather than endpoint info.
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),
|
||||
|
||||
@@ -11,8 +11,8 @@ pub use console_redirect::ConsoleRedirectBackend;
|
||||
pub(crate) use console_redirect::ConsoleRedirectError;
|
||||
use ipnet::{Ipv4Net, Ipv6Net};
|
||||
use local::LocalBackend;
|
||||
use postgres_client::config::AuthKeys;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::credentials::check_peer_addr_is_in_list;
|
||||
|
||||
@@ -227,7 +227,7 @@ pub(crate) async fn validate_password_and_exchange(
|
||||
};
|
||||
|
||||
Ok(sasl::Outcome::Success(ComputeCredentialKeys::AuthKeys(
|
||||
postgres_client::config::AuthKeys::ScramSha256(keys),
|
||||
tokio_postgres::config::AuthKeys::ScramSha256(keys),
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,11 +3,11 @@ use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use postgres_client::{CancelToken, NoTls};
|
||||
use pq_proto::CancelKeyData;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_postgres::{CancelToken, NoTls};
|
||||
use tracing::{debug, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -44,7 +44,7 @@ pub(crate) enum CancelError {
|
||||
IO(#[from] std::io::Error),
|
||||
|
||||
#[error("{0}")]
|
||||
Postgres(#[from] postgres_client::Error),
|
||||
Postgres(#[from] tokio_postgres::Error),
|
||||
|
||||
#[error("rate limit exceeded")]
|
||||
RateLimit,
|
||||
@@ -70,12 +70,11 @@ impl ReportableError for CancelError {
|
||||
impl<P: CancellationPublisher> CancellationHandler<P> {
|
||||
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
|
||||
pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
|
||||
// we intentionally generate a random "backend pid" and "secret key" here.
|
||||
// we use the corresponding u64 as an identifier for the
|
||||
// actual endpoint+pid+secret for postgres/pgbouncer.
|
||||
//
|
||||
// if we forwarded the backend_pid from postgres to the client, there would be a lot
|
||||
// of overlap between our computes as most pids are small (~100).
|
||||
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
|
||||
// expose it and we don't want to do another roundtrip to query
|
||||
// for it. The client will be able to notice that this is not the
|
||||
// actual backend_pid, but backend_pid is not used for anything
|
||||
// so it doesn't matter.
|
||||
let key = loop {
|
||||
let key = rand::random();
|
||||
|
||||
|
||||
@@ -6,8 +6,6 @@ use std::time::Duration;
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{CancelToken, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use rustls::client::danger::ServerCertVerifier;
|
||||
@@ -15,6 +13,8 @@ use rustls::crypto::ring;
|
||||
use rustls::pki_types::InvalidDnsNameError;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
use tokio_postgres::{CancelToken, RawConnection};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::auth::parse_endpoint_param;
|
||||
@@ -34,9 +34,9 @@ pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum ConnectionError {
|
||||
/// This error doesn't seem to reveal any secrets; for instance,
|
||||
/// `postgres_client::error::Kind` doesn't contain ip addresses and such.
|
||||
/// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
Postgres(#[from] postgres_client::Error),
|
||||
Postgres(#[from] tokio_postgres::Error),
|
||||
|
||||
#[error("{COULD_NOT_CONNECT}: {0}")]
|
||||
CouldNotConnect(#[from] io::Error),
|
||||
@@ -99,18 +99,18 @@ impl ReportableError for ConnectionError {
|
||||
}
|
||||
|
||||
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
|
||||
pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
|
||||
pub(crate) type ScramKeys = tokio_postgres::config::ScramKeys<32>;
|
||||
|
||||
/// A config for establishing a connection to compute node.
|
||||
/// Eventually, `postgres_client` will be replaced with something better.
|
||||
/// Eventually, `tokio_postgres` will be replaced with something better.
|
||||
/// Newtype allows us to implement methods on top of it.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ConnCfg(Box<postgres_client::Config>);
|
||||
#[derive(Clone, Default)]
|
||||
pub(crate) struct ConnCfg(Box<tokio_postgres::Config>);
|
||||
|
||||
/// Creation and initialization routines.
|
||||
impl ConnCfg {
|
||||
pub(crate) fn new(host: String, port: u16) -> Self {
|
||||
Self(Box::new(postgres_client::Config::new(host, port)))
|
||||
pub(crate) fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Reuse password or auth keys from the other config.
|
||||
@@ -124,49 +124,65 @@ impl ConnCfg {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_host(&self) -> Host {
|
||||
match self.0.get_host() {
|
||||
postgres_client::config::Host::Tcp(s) => s.into(),
|
||||
pub(crate) fn get_host(&self) -> Result<Host, WakeComputeError> {
|
||||
match self.0.get_hosts() {
|
||||
[tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()),
|
||||
// we should not have multiple address or unix addresses.
|
||||
_ => Err(WakeComputeError::BadComputeAddress(
|
||||
"invalid compute address".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply startup message params to the connection config.
|
||||
pub(crate) fn set_startup_params(
|
||||
&mut self,
|
||||
params: &StartupMessageParams,
|
||||
arbitrary_params: bool,
|
||||
) {
|
||||
if !arbitrary_params {
|
||||
self.set_param("client_encoding", "UTF8");
|
||||
pub(crate) fn set_startup_params(&mut self, params: &StartupMessageParams) {
|
||||
// Only set `user` if it's not present in the config.
|
||||
// Console redirect auth flow takes username from the console's response.
|
||||
if let (None, Some(user)) = (self.get_user(), params.get("user")) {
|
||||
self.user(user);
|
||||
}
|
||||
for (k, v) in params.iter() {
|
||||
match k {
|
||||
// Only set `user` if it's not present in the config.
|
||||
// Console redirect auth flow takes username from the console's response.
|
||||
"user" if self.user_is_set() => continue,
|
||||
"database" if self.db_is_set() => continue,
|
||||
"options" => {
|
||||
if let Some(options) = filtered_options(v) {
|
||||
self.set_param(k, &options);
|
||||
}
|
||||
}
|
||||
"user" | "database" | "application_name" | "replication" => {
|
||||
self.set_param(k, v);
|
||||
}
|
||||
|
||||
// if we allow arbitrary params, then we forward them through.
|
||||
// this is a flag for a period of backwards compatibility
|
||||
k if arbitrary_params => {
|
||||
self.set_param(k, v);
|
||||
// Only set `dbname` if it's not present in the config.
|
||||
// Console redirect auth flow takes dbname from the console's response.
|
||||
if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
|
||||
self.dbname(dbname);
|
||||
}
|
||||
|
||||
// Don't add `options` if they were only used for specifying a project.
|
||||
// Connection pools don't support `options`, because they affect backend startup.
|
||||
if let Some(options) = filtered_options(params) {
|
||||
self.options(&options);
|
||||
}
|
||||
|
||||
if let Some(app_name) = params.get("application_name") {
|
||||
self.application_name(app_name);
|
||||
}
|
||||
|
||||
// TODO: This is especially ugly...
|
||||
if let Some(replication) = params.get("replication") {
|
||||
use tokio_postgres::config::ReplicationMode;
|
||||
match replication {
|
||||
"true" | "on" | "yes" | "1" => {
|
||||
self.replication_mode(ReplicationMode::Physical);
|
||||
}
|
||||
_ => {}
|
||||
"database" => {
|
||||
self.replication_mode(ReplicationMode::Logical);
|
||||
}
|
||||
_other => {}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: extend the list of the forwarded startup parameters.
|
||||
// Currently, tokio-postgres doesn't allow us to pass
|
||||
// arbitrary parameters, but the ones above are a good start.
|
||||
//
|
||||
// This and the reverse params problem can be better addressed
|
||||
// in a bespoke connection machinery (a new library for that sake).
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for ConnCfg {
|
||||
type Target = postgres_client::Config;
|
||||
type Target = tokio_postgres::Config;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
@@ -183,7 +199,7 @@ impl std::ops::DerefMut for ConnCfg {
|
||||
impl ConnCfg {
|
||||
/// Establish a raw TCP connection to the compute node.
|
||||
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {
|
||||
use postgres_client::config::Host;
|
||||
use tokio_postgres::config::Host;
|
||||
|
||||
// wrap TcpStream::connect with timeout
|
||||
let connect_with_timeout = |host, port| {
|
||||
@@ -208,23 +224,46 @@ impl ConnCfg {
|
||||
})
|
||||
};
|
||||
|
||||
// We can't reuse connection establishing logic from `postgres_client` here,
|
||||
// We can't reuse connection establishing logic from `tokio_postgres` here,
|
||||
// because it has no means for extracting the underlying socket which we
|
||||
// require for our business.
|
||||
let port = self.0.get_port();
|
||||
let host = self.0.get_host();
|
||||
let mut connection_error = None;
|
||||
let ports = self.0.get_ports();
|
||||
let hosts = self.0.get_hosts();
|
||||
// the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array
|
||||
if ports.len() > 1 && ports.len() != hosts.len() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!(
|
||||
"bad compute config, \
|
||||
ports and hosts entries' count does not match: {:?}",
|
||||
self.0
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let host = match host {
|
||||
Host::Tcp(host) => host.as_str(),
|
||||
};
|
||||
for (i, host) in hosts.iter().enumerate() {
|
||||
let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432);
|
||||
let host = match host {
|
||||
Host::Tcp(host) => host.as_str(),
|
||||
};
|
||||
|
||||
match connect_once(host, port).await {
|
||||
Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)),
|
||||
Err(err) => {
|
||||
warn!("couldn't connect to compute node at {host}:{port}: {err}");
|
||||
Err(err)
|
||||
match connect_once(host, *port).await {
|
||||
Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)),
|
||||
Err(err) => {
|
||||
// We can't throw an error here, as there might be more hosts to try.
|
||||
warn!("couldn't connect to compute node at {host}:{port}: {err}");
|
||||
connection_error = Some(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(connection_error.unwrap_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("bad compute config: {:?}", self.0),
|
||||
)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,7 +272,7 @@ type RustlsStream = <MakeRustlsConnect as MakeTlsConnect<tokio::net::TcpStream>>
|
||||
pub(crate) struct PostgresConnection {
|
||||
/// Socket connected to a compute node.
|
||||
pub(crate) stream:
|
||||
postgres_client::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
tokio_postgres::maybe_tls_stream::MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
/// PostgreSQL connection parameters.
|
||||
pub(crate) params: std::collections::HashMap<String, String>,
|
||||
/// Query cancellation token.
|
||||
@@ -335,9 +374,10 @@ impl ConnCfg {
|
||||
}
|
||||
|
||||
/// Retrieve `options` from a startup message, dropping all proxy-secific flags.
|
||||
fn filtered_options(options: &str) -> Option<String> {
|
||||
fn filtered_options(params: &StartupMessageParams) -> Option<String> {
|
||||
#[allow(unstable_name_collisions)]
|
||||
let options: String = StartupMessageParams::parse_options_raw(options)
|
||||
let options: String = params
|
||||
.options_raw()?
|
||||
.filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
|
||||
.intersperse(" ") // TODO: use impl from std once it's stabilized
|
||||
.collect();
|
||||
@@ -414,24 +454,27 @@ mod tests {
|
||||
#[test]
|
||||
fn test_filtered_options() {
|
||||
// Empty options is unlikely to be useful anyway.
|
||||
let params = "";
|
||||
assert_eq!(filtered_options(params), None);
|
||||
let params = StartupMessageParams::new([("options", "")]);
|
||||
assert_eq!(filtered_options(¶ms), None);
|
||||
|
||||
// It's likely that clients will only use options to specify endpoint/project.
|
||||
let params = "project=foo";
|
||||
assert_eq!(filtered_options(params), None);
|
||||
let params = StartupMessageParams::new([("options", "project=foo")]);
|
||||
assert_eq!(filtered_options(¶ms), None);
|
||||
|
||||
// Same, because unescaped whitespaces are no-op.
|
||||
let params = " project=foo ";
|
||||
assert_eq!(filtered_options(params).as_deref(), None);
|
||||
let params = StartupMessageParams::new([("options", " project=foo ")]);
|
||||
assert_eq!(filtered_options(¶ms).as_deref(), None);
|
||||
|
||||
let params = r"\ project=foo \ ";
|
||||
assert_eq!(filtered_options(params).as_deref(), Some(r"\ \ "));
|
||||
let params = StartupMessageParams::new([("options", r"\ project=foo \ ")]);
|
||||
assert_eq!(filtered_options(¶ms).as_deref(), Some(r"\ \ "));
|
||||
|
||||
let params = "project = foo";
|
||||
assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
|
||||
let params = StartupMessageParams::new([("options", "project = foo")]);
|
||||
assert_eq!(filtered_options(¶ms).as_deref(), Some("project = foo"));
|
||||
|
||||
let params = "project = foo neon_endpoint_type:read_write neon_lsn:0/2 neon_proxy_params_compat:true";
|
||||
assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
|
||||
let params = StartupMessageParams::new([(
|
||||
"options",
|
||||
"project = foo neon_endpoint_type:read_write neon_lsn:0/2",
|
||||
)]);
|
||||
assert_eq!(filtered_options(¶ms).as_deref(), Some("project = foo"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,7 +206,6 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
params_compat: true,
|
||||
params: ¶ms,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::sync::Arc;
|
||||
|
||||
use futures::TryFutureExt;
|
||||
use thiserror::Error;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::Client;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
@@ -160,11 +161,11 @@ impl MockControlPlane {
|
||||
}
|
||||
|
||||
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
||||
let mut config = compute::ConnCfg::new(
|
||||
self.endpoint.host_str().unwrap_or("localhost").to_owned(),
|
||||
self.endpoint.port().unwrap_or(5432),
|
||||
);
|
||||
config.ssl_mode(postgres_client::config::SslMode::Disable);
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(self.endpoint.host_str().unwrap_or("localhost"))
|
||||
.port(self.endpoint.port().unwrap_or(5432))
|
||||
.ssl_mode(SslMode::Disable);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
|
||||
@@ -6,8 +6,8 @@ use std::time::Duration;
|
||||
use ::http::header::AUTHORIZATION;
|
||||
use ::http::HeaderName;
|
||||
use futures::TryFutureExt;
|
||||
use postgres_client::config::SslMode;
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use super::super::messages::{ControlPlaneErrorMessage, GetRoleSecret, WakeCompute};
|
||||
@@ -241,8 +241,8 @@ impl NeonControlPlaneClient {
|
||||
// Don't set anything but host and port! This config will be cached.
|
||||
// We'll set username and such later using the startup message.
|
||||
// TODO: add more type safety (in progress).
|
||||
let mut config = compute::ConnCfg::new(host.to_owned(), port);
|
||||
config.ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
|
||||
@@ -84,7 +84,7 @@ pub(crate) trait ReportableError: fmt::Display + Send + 'static {
|
||||
fn get_error_kind(&self) -> ErrorKind;
|
||||
}
|
||||
|
||||
impl ReportableError for postgres_client::error::Error {
|
||||
impl ReportableError for tokio_postgres::error::Error {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
if self.as_db_error().is_some() {
|
||||
ErrorKind::Postgres
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::sync::Arc;
|
||||
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use rustls::pki_types::ServerName;
|
||||
use rustls::ClientConfig;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
|
||||
mod private {
|
||||
use std::future::Future;
|
||||
@@ -12,9 +12,9 @@ mod private {
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use postgres_client::tls::{ChannelBinding, TlsConnect};
|
||||
use rustls::pki_types::ServerName;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
|
||||
use tokio_rustls::client::TlsStream;
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
@@ -59,7 +59,7 @@ mod private {
|
||||
|
||||
pub struct RustlsStream<S>(TlsStream<S>);
|
||||
|
||||
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
|
||||
impl<S> tokio_postgres::tls::TlsStream for RustlsStream<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
|
||||
@@ -66,8 +66,6 @@ pub(crate) trait ComputeConnectBackend {
|
||||
}
|
||||
|
||||
pub(crate) struct TcpMechanism<'a> {
|
||||
pub(crate) params_compat: bool,
|
||||
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
pub(crate) params: &'a StartupMessageParams,
|
||||
|
||||
@@ -88,13 +86,13 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
node_info: &control_plane::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
let host = node_info.config.get_host();
|
||||
let host = node_info.config.get_host()?;
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
permit.release_result(node_info.connect(ctx, timeout).await)
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
||||
config.set_startup_params(self.params, self.params_compat);
|
||||
config.set_startup_params(self.params);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -338,17 +338,9 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
};
|
||||
|
||||
let params_compat = match &user_info {
|
||||
auth::Backend::ControlPlane(_, info) => {
|
||||
info.info.options.get(NeonOptions::PARAMS_COMPAT).is_some()
|
||||
}
|
||||
auth::Backend::Local(_) => false,
|
||||
};
|
||||
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
params_compat,
|
||||
params: ¶ms,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
@@ -417,47 +409,19 @@ pub(crate) async fn prepare_client_connection<P>(
|
||||
pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
|
||||
|
||||
impl NeonOptions {
|
||||
// proxy options:
|
||||
|
||||
/// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute.
|
||||
const PARAMS_COMPAT: &str = "proxy_params_compat";
|
||||
|
||||
// cplane options:
|
||||
|
||||
/// `LSN` allows provisioning an ephemeral compute with time-travel to the provided LSN.
|
||||
const LSN: &str = "lsn";
|
||||
|
||||
/// `ENDPOINT_TYPE` allows configuring an ephemeral compute to be read_only or read_write.
|
||||
const ENDPOINT_TYPE: &str = "endpoint_type";
|
||||
|
||||
pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
|
||||
params
|
||||
.options_raw()
|
||||
.map(Self::parse_from_iter)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub(crate) fn parse_options_raw(options: &str) -> Self {
|
||||
Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
|
||||
}
|
||||
|
||||
pub(crate) fn get(&self, key: &str) -> Option<SmolStr> {
|
||||
self.0
|
||||
.iter()
|
||||
.find_map(|(k, v)| (k == key).then_some(v))
|
||||
.cloned()
|
||||
}
|
||||
|
||||
pub(crate) fn is_ephemeral(&self) -> bool {
|
||||
self.0.iter().any(|(k, _)| match &**k {
|
||||
// This is not a cplane option, we know it does not create ephemeral computes.
|
||||
Self::PARAMS_COMPAT => false,
|
||||
Self::LSN => true,
|
||||
Self::ENDPOINT_TYPE => true,
|
||||
// err on the side of caution. any cplane options we don't know about
|
||||
// might lead to ephemeral computes.
|
||||
_ => true,
|
||||
})
|
||||
// Currently, neon endpoint options are all reserved for ephemeral endpoints.
|
||||
!self.0.is_empty()
|
||||
}
|
||||
|
||||
fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
|
||||
|
||||
@@ -31,9 +31,9 @@ impl CouldRetry for io::Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for postgres_client::error::DbError {
|
||||
impl CouldRetry for tokio_postgres::error::DbError {
|
||||
fn could_retry(&self) -> bool {
|
||||
use postgres_client::error::SqlState;
|
||||
use tokio_postgres::error::SqlState;
|
||||
matches!(
|
||||
self.code(),
|
||||
&SqlState::CONNECTION_FAILURE
|
||||
@@ -43,9 +43,9 @@ impl CouldRetry for postgres_client::error::DbError {
|
||||
)
|
||||
}
|
||||
}
|
||||
impl ShouldRetryWakeCompute for postgres_client::error::DbError {
|
||||
impl ShouldRetryWakeCompute for tokio_postgres::error::DbError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
use postgres_client::error::SqlState;
|
||||
use tokio_postgres::error::SqlState;
|
||||
// Here are errors that happens after the user successfully authenticated to the database.
|
||||
// TODO: there are pgbouncer errors that should be retried, but they are not listed here.
|
||||
!matches!(
|
||||
@@ -61,21 +61,21 @@ impl ShouldRetryWakeCompute for postgres_client::error::DbError {
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for postgres_client::Error {
|
||||
impl CouldRetry for tokio_postgres::Error {
|
||||
fn could_retry(&self) -> bool {
|
||||
if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
|
||||
io::Error::could_retry(io_err)
|
||||
} else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
|
||||
postgres_client::error::DbError::could_retry(db_err)
|
||||
tokio_postgres::error::DbError::could_retry(db_err)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
impl ShouldRetryWakeCompute for postgres_client::Error {
|
||||
impl ShouldRetryWakeCompute for tokio_postgres::Error {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
|
||||
postgres_client::error::DbError::should_retry_wake_compute(db_err)
|
||||
tokio_postgres::error::DbError::should_retry_wake_compute(db_err)
|
||||
} else {
|
||||
// likely an IO error. Possible the compute has shutdown and the
|
||||
// cache is stale.
|
||||
|
||||
@@ -8,9 +8,9 @@ use std::fmt::Debug;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use postgres_client::tls::TlsConnect;
|
||||
use postgres_protocol::message::frontend;
|
||||
use tokio::io::{AsyncReadExt, DuplexStream};
|
||||
use tokio_postgres::tls::TlsConnect;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
use super::*;
|
||||
@@ -55,13 +55,7 @@ async fn proxy_mitm(
|
||||
|
||||
// give the end_server the startup parameters
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::startup_message(
|
||||
&postgres_protocol::message::frontend::StartupMessageParams {
|
||||
params: startup.params.into(),
|
||||
},
|
||||
&mut buf,
|
||||
)
|
||||
.unwrap();
|
||||
frontend::startup_message(startup.iter(), &mut buf).unwrap();
|
||||
end_server.send(buf.freeze()).await.unwrap();
|
||||
|
||||
// proxy messages between end_client and end_server
|
||||
@@ -164,8 +158,8 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
Scram::new("password").await?,
|
||||
));
|
||||
|
||||
let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
|
||||
.channel_binding(postgres_client::config::ChannelBinding::Disable)
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
@@ -181,7 +175,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::None,
|
||||
postgres_client::config::ChannelBinding::Prefer,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -191,7 +185,7 @@ async fn scram_auth_prefer_channel_binding() -> anyhow::Result<()> {
|
||||
async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::Methods,
|
||||
postgres_client::config::ChannelBinding::Prefer,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -201,7 +195,7 @@ async fn scram_auth_prefer_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::SASLResponse,
|
||||
postgres_client::config::ChannelBinding::Prefer,
|
||||
tokio_postgres::config::ChannelBinding::Prefer,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -211,7 +205,7 @@ async fn scram_auth_prefer_channel_binding_intercept_response() -> anyhow::Resul
|
||||
async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::None,
|
||||
postgres_client::config::ChannelBinding::Require,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -221,7 +215,7 @@ async fn scram_auth_require_channel_binding() -> anyhow::Result<()> {
|
||||
async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::Methods,
|
||||
postgres_client::config::ChannelBinding::Require,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -231,14 +225,14 @@ async fn scram_auth_require_channel_binding_intercept() -> anyhow::Result<()> {
|
||||
async fn scram_auth_require_channel_binding_intercept_response() -> anyhow::Result<()> {
|
||||
connect_failure(
|
||||
Intercept::SASLResponse,
|
||||
postgres_client::config::ChannelBinding::Require,
|
||||
tokio_postgres::config::ChannelBinding::Require,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn connect_failure(
|
||||
intercept: Intercept,
|
||||
channel_binding: postgres_client::config::ChannelBinding,
|
||||
channel_binding: tokio_postgres::config::ChannelBinding,
|
||||
) -> anyhow::Result<()> {
|
||||
let (server, client, client_config, server_config) = proxy_mitm(intercept).await;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
@@ -247,7 +241,7 @@ async fn connect_failure(
|
||||
Scram::new("password").await?,
|
||||
));
|
||||
|
||||
let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.channel_binding(channel_binding)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
|
||||
@@ -7,13 +7,13 @@ use std::time::Duration;
|
||||
use anyhow::{bail, Context};
|
||||
use async_trait::async_trait;
|
||||
use http::StatusCode;
|
||||
use postgres_client::config::SslMode;
|
||||
use postgres_client::tls::{MakeTlsConnect, NoTls};
|
||||
use retry::{retry_after, ShouldRetryWakeCompute};
|
||||
use rstest::rstest;
|
||||
use rustls::crypto::ring;
|
||||
use rustls::pki_types;
|
||||
use tokio::io::DuplexStream;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
|
||||
|
||||
use super::connect_compute::ConnectMechanism;
|
||||
use super::retry::CouldRetry;
|
||||
@@ -204,7 +204,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
|
||||
let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
|
||||
|
||||
let client_err = postgres_client::Config::new("test".to_owned(), 5432)
|
||||
let client_err = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Disable)
|
||||
@@ -233,7 +233,7 @@ async fn handshake_tls() -> anyhow::Result<()> {
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
|
||||
|
||||
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
|
||||
let _conn = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.ssl_mode(SslMode::Require)
|
||||
@@ -249,10 +249,10 @@ async fn handshake_raw() -> anyhow::Result<()> {
|
||||
|
||||
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
|
||||
|
||||
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
|
||||
let _conn = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.set_param("options", "project=generic-project-name")
|
||||
.options("project=generic-project-name")
|
||||
.ssl_mode(SslMode::Prefer)
|
||||
.connect_raw(server, NoTls)
|
||||
.await?;
|
||||
@@ -296,8 +296,8 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
|
||||
Scram::new(password).await?,
|
||||
));
|
||||
|
||||
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
|
||||
.channel_binding(postgres_client::config::ChannelBinding::Require)
|
||||
let _conn = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Require)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(password)
|
||||
@@ -320,8 +320,8 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
|
||||
Scram::new("password").await?,
|
||||
));
|
||||
|
||||
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
|
||||
.channel_binding(postgres_client::config::ChannelBinding::Disable)
|
||||
let _conn = tokio_postgres::Config::new()
|
||||
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password("password")
|
||||
@@ -348,7 +348,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
|
||||
let _client_err = tokio_postgres::Config::new()
|
||||
.user("user")
|
||||
.dbname("db")
|
||||
.password(&password) // no password will match the mocked secret
|
||||
@@ -546,7 +546,7 @@ impl TestControlPlaneClient for TestConnectMechanism {
|
||||
|
||||
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
||||
let node = NodeInfo {
|
||||
config: compute::ConnCfg::new("test".to_owned(), 5432),
|
||||
config: compute::ConnCfg::new(),
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
|
||||
@@ -37,9 +37,9 @@ use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
|
||||
|
||||
pub(crate) struct PoolingBackend {
|
||||
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
|
||||
pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
|
||||
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
|
||||
pub(crate) pool:
|
||||
Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
|
||||
Arc<GlobalConnPool<tokio_postgres::Client, EndpointConnPool<tokio_postgres::Client>>>,
|
||||
|
||||
pub(crate) config: &'static ProxyConfig,
|
||||
pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
|
||||
@@ -170,7 +170,7 @@ impl PoolingBackend {
|
||||
conn_info: ConnInfo,
|
||||
keys: ComputeCredentials,
|
||||
force_new: bool,
|
||||
) -> Result<Client<postgres_client::Client>, HttpConnError> {
|
||||
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
|
||||
let maybe_client = if force_new {
|
||||
debug!("pool: pool is disabled");
|
||||
None
|
||||
@@ -256,7 +256,7 @@ impl PoolingBackend {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<Client<postgres_client::Client>, HttpConnError> {
|
||||
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
|
||||
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
|
||||
return Ok(client);
|
||||
}
|
||||
@@ -309,16 +309,13 @@ impl PoolingBackend {
|
||||
.config
|
||||
.user(&conn_info.user_info.user)
|
||||
.dbname(&conn_info.dbname)
|
||||
.set_param(
|
||||
"options",
|
||||
&format!(
|
||||
"-c pg_session_jwt.jwk={}",
|
||||
serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
|
||||
),
|
||||
);
|
||||
.options(&format!(
|
||||
"-c pg_session_jwt.jwk={}",
|
||||
serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
|
||||
));
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (client, connection) = config.connect(postgres_client::NoTls).await?;
|
||||
let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
|
||||
drop(pause);
|
||||
|
||||
let pid = client.get_process_id();
|
||||
@@ -363,7 +360,7 @@ pub(crate) enum HttpConnError {
|
||||
#[error("pooled connection closed at inconsistent state")]
|
||||
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
|
||||
#[error("could not connection to postgres in compute")]
|
||||
PostgresConnectionError(#[from] postgres_client::Error),
|
||||
PostgresConnectionError(#[from] tokio_postgres::Error),
|
||||
#[error("could not connection to local-proxy in compute")]
|
||||
LocalProxyConnectionError(#[from] LocalProxyConnError),
|
||||
#[error("could not parse JWT payload")]
|
||||
@@ -482,7 +479,7 @@ impl ShouldRetryWakeCompute for LocalProxyConnError {
|
||||
}
|
||||
|
||||
struct TokioMechanism {
|
||||
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
|
||||
pool: Arc<GlobalConnPool<tokio_postgres::Client, EndpointConnPool<tokio_postgres::Client>>>,
|
||||
conn_info: ConnInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
@@ -492,7 +489,7 @@ struct TokioMechanism {
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for TokioMechanism {
|
||||
type Connection = Client<postgres_client::Client>;
|
||||
type Connection = Client<tokio_postgres::Client>;
|
||||
type ConnectError = HttpConnError;
|
||||
type Error = HttpConnError;
|
||||
|
||||
@@ -502,7 +499,7 @@ impl ConnectMechanism for TokioMechanism {
|
||||
node_info: &CachedNodeInfo,
|
||||
timeout: Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host = node_info.config.get_host();
|
||||
let host = node_info.config.get_host()?;
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
|
||||
let mut config = (*node_info.config).clone();
|
||||
@@ -512,7 +509,7 @@ impl ConnectMechanism for TokioMechanism {
|
||||
.connect_timeout(timeout);
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let res = config.connect(postgres_client::NoTls).await;
|
||||
let res = config.connect(tokio_postgres::NoTls).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
@@ -552,12 +549,16 @@ impl ConnectMechanism for HyperMechanism {
|
||||
node_info: &CachedNodeInfo,
|
||||
timeout: Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host = node_info.config.get_host();
|
||||
let host = node_info.config.get_host()?;
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
|
||||
let port = node_info.config.get_port();
|
||||
let port = *node_info.config.get_ports().first().ok_or_else(|| {
|
||||
HttpConnError::WakeCompute(WakeComputeError::BadComputeAddress(
|
||||
"local-proxy port missing on compute address".into(),
|
||||
))
|
||||
})?;
|
||||
let res = connect_http2(&host, port, timeout).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
@@ -5,11 +5,11 @@ use std::task::{ready, Poll};
|
||||
|
||||
use futures::future::poll_fn;
|
||||
use futures::Future;
|
||||
use postgres_client::tls::NoTlsStream;
|
||||
use postgres_client::AsyncMessage;
|
||||
use smallvec::SmallVec;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::AsyncMessage;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
#[cfg(test)]
|
||||
@@ -58,7 +58,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
ctx: &RequestContext,
|
||||
conn_info: ConnInfo,
|
||||
client: C,
|
||||
mut connection: postgres_client::Connection<TcpStream, NoTlsStream>,
|
||||
mut connection: tokio_postgres::Connection<TcpStream, NoTlsStream>,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client<C> {
|
||||
|
||||
@@ -7,8 +7,8 @@ use std::time::Duration;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::RwLock;
|
||||
use postgres_client::ReadyForQueryStatus;
|
||||
use rand::Rng;
|
||||
use tokio_postgres::ReadyForQueryStatus;
|
||||
use tracing::{debug, info, Span};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
@@ -683,7 +683,7 @@ pub(crate) trait ClientInnerExt: Sync + Send + 'static {
|
||||
fn get_process_id(&self) -> i32;
|
||||
}
|
||||
|
||||
impl ClientInnerExt for postgres_client::Client {
|
||||
impl ClientInnerExt for tokio_postgres::Client {
|
||||
fn is_closed(&self) -> bool {
|
||||
self.is_closed()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use postgres_client::types::{Kind, Type};
|
||||
use postgres_client::Row;
|
||||
use serde_json::{Map, Value};
|
||||
use tokio_postgres::types::{Kind, Type};
|
||||
use tokio_postgres::Row;
|
||||
|
||||
//
|
||||
// Convert json non-string types to strings, so that they can be passed to Postgres
|
||||
@@ -61,7 +61,7 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum JsonConversionError {
|
||||
#[error("internal error compute returned invalid data: {0}")]
|
||||
AsTextError(postgres_client::Error),
|
||||
AsTextError(tokio_postgres::Error),
|
||||
#[error("parse int error: {0}")]
|
||||
ParseIntError(#[from] std::num::ParseIntError),
|
||||
#[error("parse float error: {0}")]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user